modeling_layoutlmv3.py 59 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384
  1. # coding=utf-8
  2. # Copyright 2022 Microsoft Research and The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch LayoutLMv3 model."""
  16. import collections
  17. import math
  18. from typing import Optional, Tuple, Union
  19. import torch
  20. import torch.nn as nn
  21. import torch.nn.functional as F
  22. import torch.utils.checkpoint
  23. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  24. from ...activations import ACT2FN
  25. from ...modeling_outputs import (
  26. BaseModelOutput,
  27. QuestionAnsweringModelOutput,
  28. SequenceClassifierOutput,
  29. TokenClassifierOutput,
  30. )
  31. from ...modeling_utils import PreTrainedModel
  32. from ...pytorch_utils import apply_chunking_to_forward
  33. from ...utils import (
  34. add_start_docstrings,
  35. add_start_docstrings_to_model_forward,
  36. logging,
  37. replace_return_docstrings,
  38. torch_int,
  39. )
  40. from .configuration_layoutlmv3 import LayoutLMv3Config
logger = logging.get_logger(__name__)

# Name of the configuration class; presumably consumed by the docstring decorators further down the file
# (not visible here) — TODO confirm.
_CONFIG_FOR_DOC = "LayoutLMv3Config"

# Shared preamble describing LayoutLMv3 models and their `config` parameter.
LAYOUTLMV3_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`LayoutLMv3Config`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Argument documentation for the base model. `{0}` is a format placeholder for the input shape. The text
# notes that `sequence_length` counts text tokens, image patches and the [CLS] token together.
LAYOUTLMV3_MODEL_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `({0})`):
Indices of input sequence tokens in the vocabulary.
Note that `sequence_length = token_sequence_length + patch_sequence_length + 1` where `1` is for [CLS]
token. See `pixel_values` for `patch_sequence_length`.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
bbox (`torch.LongTensor` of shape `({0}, 4)`, *optional*):
Bounding boxes of each input sequence tokens. Selected in the range `[0,
config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
y1) represents the position of the lower right corner.
Note that `sequence_length = token_sequence_length + patch_sequence_length + 1` where `1` is for [CLS]
token. See `pixel_values` for `patch_sequence_length`.
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Batch of document images. Each image is divided into patches of shape `(num_channels, config.patch_size,
config.patch_size)` and the total number of patches (=`patch_sequence_length`) equals to `((height /
config.patch_size) * (width / config.patch_size))`.
attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
Note that `sequence_length = token_sequence_length + patch_sequence_length + 1` where `1` is for [CLS]
token. See `pixel_values` for `patch_sequence_length`.
[What are attention masks?](../glossary#attention-mask)
token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
1]`:
- 0 corresponds to a *sentence A* token,
- 1 corresponds to a *sentence B* token.
Note that `sequence_length = token_sequence_length + patch_sequence_length + 1` where `1` is for [CLS]
token. See `pixel_values` for `patch_sequence_length`.
[What are token type IDs?](../glossary#token-type-ids)
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
Note that `sequence_length = token_sequence_length + patch_sequence_length + 1` where `1` is for [CLS]
token. See `pixel_values` for `patch_sequence_length`.
[What are position IDs?](../glossary#position-ids)
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
model's internal embedding lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

# Argument documentation for the task-specific (downstream) heads. It is the same as the base-model text but
# without the combined token/patch sequence-length notes. `{0}` is again the input-shape placeholder.
LAYOUTLMV3_DOWNSTREAM_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `({0})`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
bbox (`torch.LongTensor` of shape `({0}, 4)`, *optional*):
Bounding boxes of each input sequence tokens. Selected in the range `[0,
config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
y1) represents the position of the lower right corner.
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Batch of document images. Each image is divided into patches of shape `(num_channels, config.patch_size,
config.patch_size)` and the total number of patches (=`patch_sequence_length`) equals to `((height /
config.patch_size) * (width / config.patch_size))`.
attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
1]`:
- 0 corresponds to a *sentence A* token,
- 1 corresponds to a *sentence B* token.
[What are token type IDs?](../glossary#token-type-ids)
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
model's internal embedding lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
  158. class LayoutLMv3PatchEmbeddings(nn.Module):
  159. """LayoutLMv3 image (patch) embeddings. This class also automatically interpolates the position embeddings for varying
  160. image sizes."""
  161. def __init__(self, config):
  162. super().__init__()
  163. image_size = (
  164. config.input_size
  165. if isinstance(config.input_size, collections.abc.Iterable)
  166. else (config.input_size, config.input_size)
  167. )
  168. patch_size = (
  169. config.patch_size
  170. if isinstance(config.patch_size, collections.abc.Iterable)
  171. else (config.patch_size, config.patch_size)
  172. )
  173. self.patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
  174. self.proj = nn.Conv2d(config.num_channels, config.hidden_size, kernel_size=patch_size, stride=patch_size)
  175. def forward(self, pixel_values, position_embedding=None):
  176. embeddings = self.proj(pixel_values)
  177. if position_embedding is not None:
  178. # interpolate the position embedding to the corresponding size
  179. position_embedding = position_embedding.view(1, self.patch_shape[0], self.patch_shape[1], -1)
  180. position_embedding = position_embedding.permute(0, 3, 1, 2)
  181. patch_height, patch_width = embeddings.shape[2], embeddings.shape[3]
  182. position_embedding = F.interpolate(position_embedding, size=(patch_height, patch_width), mode="bicubic")
  183. embeddings = embeddings + position_embedding
  184. embeddings = embeddings.flatten(2).transpose(1, 2)
  185. return embeddings
  186. class LayoutLMv3TextEmbeddings(nn.Module):
  187. """
  188. LayoutLMv3 text embeddings. Same as `RobertaEmbeddings` but with added spatial (layout) embeddings.
  189. """
  190. def __init__(self, config):
  191. super().__init__()
  192. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  193. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  194. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  195. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  196. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  197. self.register_buffer(
  198. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  199. )
  200. self.padding_idx = config.pad_token_id
  201. self.position_embeddings = nn.Embedding(
  202. config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
  203. )
  204. self.x_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.coordinate_size)
  205. self.y_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.coordinate_size)
  206. self.h_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.shape_size)
  207. self.w_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.shape_size)
  208. def calculate_spatial_position_embeddings(self, bbox):
  209. try:
  210. left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])
  211. upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1])
  212. right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2])
  213. lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3])
  214. except IndexError as e:
  215. raise IndexError("The `bbox` coordinate values should be within 0-1000 range.") from e
  216. h_position_embeddings = self.h_position_embeddings(torch.clip(bbox[:, :, 3] - bbox[:, :, 1], 0, 1023))
  217. w_position_embeddings = self.w_position_embeddings(torch.clip(bbox[:, :, 2] - bbox[:, :, 0], 0, 1023))
  218. # below is the difference between LayoutLMEmbeddingsV2 (torch.cat) and LayoutLMEmbeddingsV1 (add)
  219. spatial_position_embeddings = torch.cat(
  220. [
  221. left_position_embeddings,
  222. upper_position_embeddings,
  223. right_position_embeddings,
  224. lower_position_embeddings,
  225. h_position_embeddings,
  226. w_position_embeddings,
  227. ],
  228. dim=-1,
  229. )
  230. return spatial_position_embeddings
  231. def create_position_ids_from_input_ids(self, input_ids, padding_idx):
  232. """
  233. Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
  234. symbols are ignored. This is modified from fairseq's `utils.make_positions`.
  235. """
  236. # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
  237. mask = input_ids.ne(padding_idx).int()
  238. incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask)) * mask
  239. return incremental_indices.long() + padding_idx
  240. def create_position_ids_from_inputs_embeds(self, inputs_embeds):
  241. """
  242. We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
  243. """
  244. input_shape = inputs_embeds.size()[:-1]
  245. sequence_length = input_shape[1]
  246. position_ids = torch.arange(
  247. self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
  248. )
  249. return position_ids.unsqueeze(0).expand(input_shape)
  250. def forward(
  251. self,
  252. input_ids=None,
  253. bbox=None,
  254. token_type_ids=None,
  255. position_ids=None,
  256. inputs_embeds=None,
  257. ):
  258. if position_ids is None:
  259. if input_ids is not None:
  260. # Create the position ids from the input token ids. Any padded tokens remain padded.
  261. position_ids = self.create_position_ids_from_input_ids(input_ids, self.padding_idx).to(
  262. input_ids.device
  263. )
  264. else:
  265. position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
  266. if input_ids is not None:
  267. input_shape = input_ids.size()
  268. else:
  269. input_shape = inputs_embeds.size()[:-1]
  270. if token_type_ids is None:
  271. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  272. if inputs_embeds is None:
  273. inputs_embeds = self.word_embeddings(input_ids)
  274. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  275. embeddings = inputs_embeds + token_type_embeddings
  276. position_embeddings = self.position_embeddings(position_ids)
  277. embeddings += position_embeddings
  278. spatial_position_embeddings = self.calculate_spatial_position_embeddings(bbox)
  279. embeddings = embeddings + spatial_position_embeddings
  280. embeddings = self.LayerNorm(embeddings)
  281. embeddings = self.dropout(embeddings)
  282. return embeddings
  283. class LayoutLMv3PreTrainedModel(PreTrainedModel):
  284. """
  285. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  286. models.
  287. """
  288. config_class = LayoutLMv3Config
  289. base_model_prefix = "layoutlmv3"
  290. def _init_weights(self, module):
  291. """Initialize the weights"""
  292. if isinstance(module, (nn.Linear, nn.Conv2d)):
  293. # Slightly different from the TF version which uses truncated_normal for initialization
  294. # cf https://github.com/pytorch/pytorch/pull/5617
  295. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  296. if module.bias is not None:
  297. module.bias.data.zero_()
  298. elif isinstance(module, nn.Embedding):
  299. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  300. if module.padding_idx is not None:
  301. module.weight.data[module.padding_idx].zero_()
  302. elif isinstance(module, nn.LayerNorm):
  303. module.bias.data.zero_()
  304. module.weight.data.fill_(1.0)
  305. class LayoutLMv3SelfAttention(nn.Module):
  306. def __init__(self, config):
  307. super().__init__()
  308. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  309. raise ValueError(
  310. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  311. f"heads ({config.num_attention_heads})"
  312. )
  313. self.num_attention_heads = config.num_attention_heads
  314. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  315. self.all_head_size = self.num_attention_heads * self.attention_head_size
  316. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  317. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  318. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  319. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  320. self.has_relative_attention_bias = config.has_relative_attention_bias
  321. self.has_spatial_attention_bias = config.has_spatial_attention_bias
  322. def transpose_for_scores(self, x):
  323. new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
  324. x = x.view(*new_x_shape)
  325. return x.permute(0, 2, 1, 3)
  326. def cogview_attention(self, attention_scores, alpha=32):
  327. """
  328. https://arxiv.org/abs/2105.13290 Section 2.4 Stabilization of training: Precision Bottleneck Relaxation
  329. (PB-Relax). A replacement of the original nn.Softmax(dim=-1)(attention_scores). Seems the new attention_probs
  330. will result in a slower speed and a little bias. Can use torch.allclose(standard_attention_probs,
  331. cogview_attention_probs, atol=1e-08) for comparison. The smaller atol (e.g., 1e-08), the better.
  332. """
  333. scaled_attention_scores = attention_scores / alpha
  334. max_value = scaled_attention_scores.amax(dim=(-1)).unsqueeze(-1)
  335. new_attention_scores = (scaled_attention_scores - max_value) * alpha
  336. return nn.Softmax(dim=-1)(new_attention_scores)
  337. def forward(
  338. self,
  339. hidden_states,
  340. attention_mask=None,
  341. head_mask=None,
  342. output_attentions=False,
  343. rel_pos=None,
  344. rel_2d_pos=None,
  345. ):
  346. mixed_query_layer = self.query(hidden_states)
  347. key_layer = self.transpose_for_scores(self.key(hidden_states))
  348. value_layer = self.transpose_for_scores(self.value(hidden_states))
  349. query_layer = self.transpose_for_scores(mixed_query_layer)
  350. # Take the dot product between "query" and "key" to get the raw attention scores.
  351. # The attention scores QT K/√d could be significantly larger than input elements, and result in overflow.
  352. # Changing the computational order into QT(K/√d) alleviates the problem. (https://arxiv.org/pdf/2105.13290.pdf)
  353. attention_scores = torch.matmul(query_layer / math.sqrt(self.attention_head_size), key_layer.transpose(-1, -2))
  354. if self.has_relative_attention_bias and self.has_spatial_attention_bias:
  355. attention_scores += (rel_pos + rel_2d_pos) / math.sqrt(self.attention_head_size)
  356. elif self.has_relative_attention_bias:
  357. attention_scores += rel_pos / math.sqrt(self.attention_head_size)
  358. if attention_mask is not None:
  359. # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function)
  360. attention_scores = attention_scores + attention_mask
  361. # Normalize the attention scores to probabilities.
  362. # Use the trick of the CogView paper to stablize training
  363. attention_probs = self.cogview_attention(attention_scores)
  364. # This is actually dropping out entire tokens to attend to, which might
  365. # seem a bit unusual, but is taken from the original Transformer paper.
  366. attention_probs = self.dropout(attention_probs)
  367. # Mask heads if we want to
  368. if head_mask is not None:
  369. attention_probs = attention_probs * head_mask
  370. context_layer = torch.matmul(attention_probs, value_layer)
  371. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  372. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  373. context_layer = context_layer.view(*new_context_layer_shape)
  374. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  375. return outputs
  376. # Copied from transformers.models.roberta.modeling_roberta.RobertaSelfOutput
  377. class LayoutLMv3SelfOutput(nn.Module):
  378. def __init__(self, config):
  379. super().__init__()
  380. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  381. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  382. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  383. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  384. hidden_states = self.dense(hidden_states)
  385. hidden_states = self.dropout(hidden_states)
  386. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  387. return hidden_states
  388. # Copied from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Attention with LayoutLMv2->LayoutLMv3
  389. class LayoutLMv3Attention(nn.Module):
  390. def __init__(self, config):
  391. super().__init__()
  392. self.self = LayoutLMv3SelfAttention(config)
  393. self.output = LayoutLMv3SelfOutput(config)
  394. def forward(
  395. self,
  396. hidden_states,
  397. attention_mask=None,
  398. head_mask=None,
  399. output_attentions=False,
  400. rel_pos=None,
  401. rel_2d_pos=None,
  402. ):
  403. self_outputs = self.self(
  404. hidden_states,
  405. attention_mask,
  406. head_mask,
  407. output_attentions,
  408. rel_pos=rel_pos,
  409. rel_2d_pos=rel_2d_pos,
  410. )
  411. attention_output = self.output(self_outputs[0], hidden_states)
  412. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  413. return outputs
  414. # Copied from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Layer with LayoutLMv2->LayoutLMv3
  415. class LayoutLMv3Layer(nn.Module):
  416. def __init__(self, config):
  417. super().__init__()
  418. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  419. self.seq_len_dim = 1
  420. self.attention = LayoutLMv3Attention(config)
  421. self.intermediate = LayoutLMv3Intermediate(config)
  422. self.output = LayoutLMv3Output(config)
  423. def forward(
  424. self,
  425. hidden_states,
  426. attention_mask=None,
  427. head_mask=None,
  428. output_attentions=False,
  429. rel_pos=None,
  430. rel_2d_pos=None,
  431. ):
  432. self_attention_outputs = self.attention(
  433. hidden_states,
  434. attention_mask,
  435. head_mask,
  436. output_attentions=output_attentions,
  437. rel_pos=rel_pos,
  438. rel_2d_pos=rel_2d_pos,
  439. )
  440. attention_output = self_attention_outputs[0]
  441. outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
  442. layer_output = apply_chunking_to_forward(
  443. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
  444. )
  445. outputs = (layer_output,) + outputs
  446. return outputs
  447. def feed_forward_chunk(self, attention_output):
  448. intermediate_output = self.intermediate(attention_output)
  449. layer_output = self.output(intermediate_output, attention_output)
  450. return layer_output
  451. class LayoutLMv3Encoder(nn.Module):
  452. def __init__(self, config):
  453. super().__init__()
  454. self.config = config
  455. self.layer = nn.ModuleList([LayoutLMv3Layer(config) for _ in range(config.num_hidden_layers)])
  456. self.gradient_checkpointing = False
  457. self.has_relative_attention_bias = config.has_relative_attention_bias
  458. self.has_spatial_attention_bias = config.has_spatial_attention_bias
  459. if self.has_relative_attention_bias:
  460. self.rel_pos_bins = config.rel_pos_bins
  461. self.max_rel_pos = config.max_rel_pos
  462. self.rel_pos_bias = nn.Linear(self.rel_pos_bins, config.num_attention_heads, bias=False)
  463. if self.has_spatial_attention_bias:
  464. self.max_rel_2d_pos = config.max_rel_2d_pos
  465. self.rel_2d_pos_bins = config.rel_2d_pos_bins
  466. self.rel_pos_x_bias = nn.Linear(self.rel_2d_pos_bins, config.num_attention_heads, bias=False)
  467. self.rel_pos_y_bias = nn.Linear(self.rel_2d_pos_bins, config.num_attention_heads, bias=False)
  468. def relative_position_bucket(self, relative_position, bidirectional=True, num_buckets=32, max_distance=128):
  469. ret = 0
  470. if bidirectional:
  471. num_buckets //= 2
  472. ret += (relative_position > 0).long() * num_buckets
  473. n = torch.abs(relative_position)
  474. else:
  475. n = torch.max(-relative_position, torch.zeros_like(relative_position))
  476. # now n is in the range [0, inf)
  477. # half of the buckets are for exact increments in positions
  478. max_exact = num_buckets // 2
  479. is_small = n < max_exact
  480. # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
  481. val_if_large = max_exact + (
  482. torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
  483. ).to(torch.long)
  484. val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
  485. ret += torch.where(is_small, n, val_if_large)
  486. return ret
  487. def _cal_1d_pos_emb(self, position_ids):
  488. rel_pos_mat = position_ids.unsqueeze(-2) - position_ids.unsqueeze(-1)
  489. rel_pos = self.relative_position_bucket(
  490. rel_pos_mat,
  491. num_buckets=self.rel_pos_bins,
  492. max_distance=self.max_rel_pos,
  493. )
  494. # Since this is a simple indexing operation that is independent of the input,
  495. # no need to track gradients for this operation
  496. #
  497. # Without this no_grad context, training speed slows down significantly
  498. with torch.no_grad():
  499. rel_pos = self.rel_pos_bias.weight.t()[rel_pos].permute(0, 3, 1, 2)
  500. rel_pos = rel_pos.contiguous()
  501. return rel_pos
  502. def _cal_2d_pos_emb(self, bbox):
  503. position_coord_x = bbox[:, :, 0]
  504. position_coord_y = bbox[:, :, 3]
  505. rel_pos_x_2d_mat = position_coord_x.unsqueeze(-2) - position_coord_x.unsqueeze(-1)
  506. rel_pos_y_2d_mat = position_coord_y.unsqueeze(-2) - position_coord_y.unsqueeze(-1)
  507. rel_pos_x = self.relative_position_bucket(
  508. rel_pos_x_2d_mat,
  509. num_buckets=self.rel_2d_pos_bins,
  510. max_distance=self.max_rel_2d_pos,
  511. )
  512. rel_pos_y = self.relative_position_bucket(
  513. rel_pos_y_2d_mat,
  514. num_buckets=self.rel_2d_pos_bins,
  515. max_distance=self.max_rel_2d_pos,
  516. )
  517. # Since this is a simple indexing operation that is independent of the input,
  518. # no need to track gradients for this operation
  519. #
  520. # Without this no_grad context, training speed slows down significantly
  521. with torch.no_grad():
  522. rel_pos_x = self.rel_pos_x_bias.weight.t()[rel_pos_x].permute(0, 3, 1, 2)
  523. rel_pos_y = self.rel_pos_y_bias.weight.t()[rel_pos_y].permute(0, 3, 1, 2)
  524. rel_pos_x = rel_pos_x.contiguous()
  525. rel_pos_y = rel_pos_y.contiguous()
  526. rel_2d_pos = rel_pos_x + rel_pos_y
  527. return rel_2d_pos
  528. def forward(
  529. self,
  530. hidden_states,
  531. bbox=None,
  532. attention_mask=None,
  533. head_mask=None,
  534. output_attentions=False,
  535. output_hidden_states=False,
  536. return_dict=True,
  537. position_ids=None,
  538. patch_height=None,
  539. patch_width=None,
  540. ):
  541. all_hidden_states = () if output_hidden_states else None
  542. all_self_attentions = () if output_attentions else None
  543. rel_pos = self._cal_1d_pos_emb(position_ids) if self.has_relative_attention_bias else None
  544. rel_2d_pos = self._cal_2d_pos_emb(bbox) if self.has_spatial_attention_bias else None
  545. for i, layer_module in enumerate(self.layer):
  546. if output_hidden_states:
  547. all_hidden_states = all_hidden_states + (hidden_states,)
  548. layer_head_mask = head_mask[i] if head_mask is not None else None
  549. if self.gradient_checkpointing and self.training:
  550. layer_outputs = self._gradient_checkpointing_func(
  551. layer_module.__call__,
  552. hidden_states,
  553. attention_mask,
  554. layer_head_mask,
  555. output_attentions,
  556. rel_pos,
  557. rel_2d_pos,
  558. )
  559. else:
  560. layer_outputs = layer_module(
  561. hidden_states,
  562. attention_mask,
  563. layer_head_mask,
  564. output_attentions,
  565. rel_pos=rel_pos,
  566. rel_2d_pos=rel_2d_pos,
  567. )
  568. hidden_states = layer_outputs[0]
  569. if output_attentions:
  570. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  571. if output_hidden_states:
  572. all_hidden_states = all_hidden_states + (hidden_states,)
  573. if not return_dict:
  574. return tuple(
  575. v
  576. for v in [
  577. hidden_states,
  578. all_hidden_states,
  579. all_self_attentions,
  580. ]
  581. if v is not None
  582. )
  583. return BaseModelOutput(
  584. last_hidden_state=hidden_states,
  585. hidden_states=all_hidden_states,
  586. attentions=all_self_attentions,
  587. )
  588. # Copied from transformers.models.roberta.modeling_roberta.RobertaIntermediate
  589. class LayoutLMv3Intermediate(nn.Module):
  590. def __init__(self, config):
  591. super().__init__()
  592. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  593. if isinstance(config.hidden_act, str):
  594. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  595. else:
  596. self.intermediate_act_fn = config.hidden_act
  597. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  598. hidden_states = self.dense(hidden_states)
  599. hidden_states = self.intermediate_act_fn(hidden_states)
  600. return hidden_states
  601. # Copied from transformers.models.roberta.modeling_roberta.RobertaOutput
  602. class LayoutLMv3Output(nn.Module):
  603. def __init__(self, config):
  604. super().__init__()
  605. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  606. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  607. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  608. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  609. hidden_states = self.dense(hidden_states)
  610. hidden_states = self.dropout(hidden_states)
  611. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  612. return hidden_states
  613. @add_start_docstrings(
  614. "The bare LayoutLMv3 Model transformer outputting raw hidden-states without any specific head on top.",
  615. LAYOUTLMV3_START_DOCSTRING,
  616. )
  617. class LayoutLMv3Model(LayoutLMv3PreTrainedModel):
  618. def __init__(self, config):
  619. super().__init__(config)
  620. self.config = config
  621. if config.text_embed:
  622. self.embeddings = LayoutLMv3TextEmbeddings(config)
  623. if config.visual_embed:
  624. # use the default pre-training parameters for fine-tuning (e.g., input_size)
  625. # when the input_size is larger in fine-tuning, we will interpolate the position embeddings in forward
  626. self.patch_embed = LayoutLMv3PatchEmbeddings(config)
  627. size = int(config.input_size / config.patch_size)
  628. self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  629. self.pos_embed = nn.Parameter(torch.zeros(1, size * size + 1, config.hidden_size))
  630. self.pos_drop = nn.Dropout(p=0.0)
  631. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  632. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  633. if self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias:
  634. self.init_visual_bbox(image_size=(size, size))
  635. self.norm = nn.LayerNorm(config.hidden_size, eps=1e-6)
  636. self.encoder = LayoutLMv3Encoder(config)
  637. self.init_weights()
  638. def get_input_embeddings(self):
  639. return self.embeddings.word_embeddings
  640. def set_input_embeddings(self, value):
  641. self.embeddings.word_embeddings = value
  642. def _prune_heads(self, heads_to_prune):
  643. """
  644. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  645. class PreTrainedModel
  646. """
  647. for layer, heads in heads_to_prune.items():
  648. self.encoder.layer[layer].attention.prune_heads(heads)
  649. def init_visual_bbox(self, image_size=(14, 14), max_len=1000):
  650. """
  651. Create the bounding boxes for the visual (patch) tokens.
  652. """
  653. visual_bbox_x = torch.div(
  654. torch.arange(0, max_len * (image_size[1] + 1), max_len), image_size[1], rounding_mode="trunc"
  655. )
  656. visual_bbox_y = torch.div(
  657. torch.arange(0, max_len * (image_size[0] + 1), max_len), image_size[0], rounding_mode="trunc"
  658. )
  659. visual_bbox = torch.stack(
  660. [
  661. visual_bbox_x[:-1].repeat(image_size[0], 1),
  662. visual_bbox_y[:-1].repeat(image_size[1], 1).transpose(0, 1),
  663. visual_bbox_x[1:].repeat(image_size[0], 1),
  664. visual_bbox_y[1:].repeat(image_size[1], 1).transpose(0, 1),
  665. ],
  666. dim=-1,
  667. ).view(-1, 4)
  668. cls_token_box = torch.tensor([[0 + 1, 0 + 1, max_len - 1, max_len - 1]])
  669. self.visual_bbox = torch.cat([cls_token_box, visual_bbox], dim=0)
  670. def calculate_visual_bbox(self, device, dtype, batch_size):
  671. visual_bbox = self.visual_bbox.repeat(batch_size, 1, 1)
  672. visual_bbox = visual_bbox.to(device).type(dtype)
  673. return visual_bbox
  674. def forward_image(self, pixel_values):
  675. embeddings = self.patch_embed(pixel_values)
  676. # add [CLS] token
  677. batch_size, seq_len, _ = embeddings.size()
  678. cls_tokens = self.cls_token.expand(batch_size, -1, -1)
  679. embeddings = torch.cat((cls_tokens, embeddings), dim=1)
  680. # add position embeddings
  681. if self.pos_embed is not None:
  682. embeddings = embeddings + self.pos_embed
  683. embeddings = self.pos_drop(embeddings)
  684. embeddings = self.norm(embeddings)
  685. return embeddings
  686. @add_start_docstrings_to_model_forward(
  687. LAYOUTLMV3_MODEL_INPUTS_DOCSTRING.format("batch_size, token_sequence_length")
  688. )
  689. @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
  690. def forward(
  691. self,
  692. input_ids: Optional[torch.LongTensor] = None,
  693. bbox: Optional[torch.LongTensor] = None,
  694. attention_mask: Optional[torch.FloatTensor] = None,
  695. token_type_ids: Optional[torch.LongTensor] = None,
  696. position_ids: Optional[torch.LongTensor] = None,
  697. head_mask: Optional[torch.FloatTensor] = None,
  698. inputs_embeds: Optional[torch.FloatTensor] = None,
  699. pixel_values: Optional[torch.FloatTensor] = None,
  700. output_attentions: Optional[bool] = None,
  701. output_hidden_states: Optional[bool] = None,
  702. return_dict: Optional[bool] = None,
  703. ) -> Union[Tuple, BaseModelOutput]:
  704. r"""
  705. Returns:
  706. Examples:
  707. ```python
  708. >>> from transformers import AutoProcessor, AutoModel
  709. >>> from datasets import load_dataset
  710. >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
  711. >>> model = AutoModel.from_pretrained("microsoft/layoutlmv3-base")
  712. >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train", trust_remote_code=True)
  713. >>> example = dataset[0]
  714. >>> image = example["image"]
  715. >>> words = example["tokens"]
  716. >>> boxes = example["bboxes"]
  717. >>> encoding = processor(image, words, boxes=boxes, return_tensors="pt")
  718. >>> outputs = model(**encoding)
  719. >>> last_hidden_states = outputs.last_hidden_state
  720. ```"""
  721. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  722. output_hidden_states = (
  723. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  724. )
  725. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  726. if input_ids is not None:
  727. input_shape = input_ids.size()
  728. batch_size, seq_length = input_shape
  729. device = input_ids.device
  730. elif inputs_embeds is not None:
  731. input_shape = inputs_embeds.size()[:-1]
  732. batch_size, seq_length = input_shape
  733. device = inputs_embeds.device
  734. elif pixel_values is not None:
  735. batch_size = len(pixel_values)
  736. device = pixel_values.device
  737. else:
  738. raise ValueError("You have to specify either input_ids or inputs_embeds or pixel_values")
  739. if input_ids is not None or inputs_embeds is not None:
  740. if attention_mask is None:
  741. attention_mask = torch.ones(((batch_size, seq_length)), device=device)
  742. if token_type_ids is None:
  743. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  744. if bbox is None:
  745. bbox = torch.zeros(tuple(list(input_shape) + [4]), dtype=torch.long, device=device)
  746. embedding_output = self.embeddings(
  747. input_ids=input_ids,
  748. bbox=bbox,
  749. position_ids=position_ids,
  750. token_type_ids=token_type_ids,
  751. inputs_embeds=inputs_embeds,
  752. )
  753. final_bbox = final_position_ids = None
  754. patch_height = patch_width = None
  755. if pixel_values is not None:
  756. patch_height, patch_width = (
  757. torch_int(pixel_values.shape[2] / self.config.patch_size),
  758. torch_int(pixel_values.shape[3] / self.config.patch_size),
  759. )
  760. visual_embeddings = self.forward_image(pixel_values)
  761. visual_attention_mask = torch.ones(
  762. (batch_size, visual_embeddings.shape[1]), dtype=torch.long, device=device
  763. )
  764. if attention_mask is not None:
  765. attention_mask = torch.cat([attention_mask, visual_attention_mask], dim=1)
  766. else:
  767. attention_mask = visual_attention_mask
  768. if self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias:
  769. if self.config.has_spatial_attention_bias:
  770. visual_bbox = self.calculate_visual_bbox(device, dtype=torch.long, batch_size=batch_size)
  771. if bbox is not None:
  772. final_bbox = torch.cat([bbox, visual_bbox], dim=1)
  773. else:
  774. final_bbox = visual_bbox
  775. visual_position_ids = torch.arange(
  776. 0, visual_embeddings.shape[1], dtype=torch.long, device=device
  777. ).repeat(batch_size, 1)
  778. if input_ids is not None or inputs_embeds is not None:
  779. position_ids = torch.arange(0, input_shape[1], device=device).unsqueeze(0)
  780. position_ids = position_ids.expand(input_shape)
  781. final_position_ids = torch.cat([position_ids, visual_position_ids], dim=1)
  782. else:
  783. final_position_ids = visual_position_ids
  784. if input_ids is not None or inputs_embeds is not None:
  785. embedding_output = torch.cat([embedding_output, visual_embeddings], dim=1)
  786. else:
  787. embedding_output = visual_embeddings
  788. embedding_output = self.LayerNorm(embedding_output)
  789. embedding_output = self.dropout(embedding_output)
  790. elif self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias:
  791. if self.config.has_spatial_attention_bias:
  792. final_bbox = bbox
  793. if self.config.has_relative_attention_bias:
  794. position_ids = self.embeddings.position_ids[:, : input_shape[1]]
  795. position_ids = position_ids.expand_as(input_ids)
  796. final_position_ids = position_ids
  797. extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
  798. attention_mask, None, device, dtype=embedding_output.dtype
  799. )
  800. # Prepare head mask if needed
  801. # 1.0 in head_mask indicate we keep the head
  802. # attention_probs has shape bsz x n_heads x N x N
  803. # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  804. # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  805. head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
  806. encoder_outputs = self.encoder(
  807. embedding_output,
  808. bbox=final_bbox,
  809. position_ids=final_position_ids,
  810. attention_mask=extended_attention_mask,
  811. head_mask=head_mask,
  812. output_attentions=output_attentions,
  813. output_hidden_states=output_hidden_states,
  814. return_dict=return_dict,
  815. patch_height=patch_height,
  816. patch_width=patch_width,
  817. )
  818. sequence_output = encoder_outputs[0]
  819. if not return_dict:
  820. return (sequence_output,) + encoder_outputs[1:]
  821. return BaseModelOutput(
  822. last_hidden_state=sequence_output,
  823. hidden_states=encoder_outputs.hidden_states,
  824. attentions=encoder_outputs.attentions,
  825. )
  826. class LayoutLMv3ClassificationHead(nn.Module):
  827. """
  828. Head for sentence-level classification tasks. Reference: RobertaClassificationHead
  829. """
  830. def __init__(self, config, pool_feature=False):
  831. super().__init__()
  832. self.pool_feature = pool_feature
  833. if pool_feature:
  834. self.dense = nn.Linear(config.hidden_size * 3, config.hidden_size)
  835. else:
  836. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  837. classifier_dropout = (
  838. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  839. )
  840. self.dropout = nn.Dropout(classifier_dropout)
  841. self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
  842. def forward(self, x):
  843. x = self.dropout(x)
  844. x = self.dense(x)
  845. x = torch.tanh(x)
  846. x = self.dropout(x)
  847. x = self.out_proj(x)
  848. return x
  849. @add_start_docstrings(
  850. """
  851. LayoutLMv3 Model with a token classification head on top (a linear layer on top of the final hidden states) e.g.
  852. for sequence labeling (information extraction) tasks such as [FUNSD](https://guillaumejaume.github.io/FUNSD/),
  853. [SROIE](https://rrc.cvc.uab.es/?ch=13), [CORD](https://github.com/clovaai/cord) and
  854. [Kleister-NDA](https://github.com/applicaai/kleister-nda).
  855. """,
  856. LAYOUTLMV3_START_DOCSTRING,
  857. )
  858. class LayoutLMv3ForTokenClassification(LayoutLMv3PreTrainedModel):
  859. def __init__(self, config):
  860. super().__init__(config)
  861. self.num_labels = config.num_labels
  862. self.layoutlmv3 = LayoutLMv3Model(config)
  863. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  864. if config.num_labels < 10:
  865. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  866. else:
  867. self.classifier = LayoutLMv3ClassificationHead(config, pool_feature=False)
  868. self.init_weights()
  869. @add_start_docstrings_to_model_forward(
  870. LAYOUTLMV3_DOWNSTREAM_INPUTS_DOCSTRING.format("batch_size, sequence_length")
  871. )
  872. @replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC)
  873. def forward(
  874. self,
  875. input_ids: Optional[torch.LongTensor] = None,
  876. bbox: Optional[torch.LongTensor] = None,
  877. attention_mask: Optional[torch.FloatTensor] = None,
  878. token_type_ids: Optional[torch.LongTensor] = None,
  879. position_ids: Optional[torch.LongTensor] = None,
  880. head_mask: Optional[torch.FloatTensor] = None,
  881. inputs_embeds: Optional[torch.FloatTensor] = None,
  882. labels: Optional[torch.LongTensor] = None,
  883. output_attentions: Optional[bool] = None,
  884. output_hidden_states: Optional[bool] = None,
  885. return_dict: Optional[bool] = None,
  886. pixel_values: Optional[torch.LongTensor] = None,
  887. ) -> Union[Tuple, TokenClassifierOutput]:
  888. r"""
  889. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  890. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  891. Returns:
  892. Examples:
  893. ```python
  894. >>> from transformers import AutoProcessor, AutoModelForTokenClassification
  895. >>> from datasets import load_dataset
  896. >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
  897. >>> model = AutoModelForTokenClassification.from_pretrained("microsoft/layoutlmv3-base", num_labels=7)
  898. >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train", trust_remote_code=True)
  899. >>> example = dataset[0]
  900. >>> image = example["image"]
  901. >>> words = example["tokens"]
  902. >>> boxes = example["bboxes"]
  903. >>> word_labels = example["ner_tags"]
  904. >>> encoding = processor(image, words, boxes=boxes, word_labels=word_labels, return_tensors="pt")
  905. >>> outputs = model(**encoding)
  906. >>> loss = outputs.loss
  907. >>> logits = outputs.logits
  908. ```"""
  909. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  910. outputs = self.layoutlmv3(
  911. input_ids,
  912. bbox=bbox,
  913. attention_mask=attention_mask,
  914. token_type_ids=token_type_ids,
  915. position_ids=position_ids,
  916. head_mask=head_mask,
  917. inputs_embeds=inputs_embeds,
  918. output_attentions=output_attentions,
  919. output_hidden_states=output_hidden_states,
  920. return_dict=return_dict,
  921. pixel_values=pixel_values,
  922. )
  923. if input_ids is not None:
  924. input_shape = input_ids.size()
  925. else:
  926. input_shape = inputs_embeds.size()[:-1]
  927. seq_length = input_shape[1]
  928. # only take the text part of the output representations
  929. sequence_output = outputs[0][:, :seq_length]
  930. sequence_output = self.dropout(sequence_output)
  931. logits = self.classifier(sequence_output)
  932. loss = None
  933. if labels is not None:
  934. loss_fct = CrossEntropyLoss()
  935. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  936. if not return_dict:
  937. output = (logits,) + outputs[1:]
  938. return ((loss,) + output) if loss is not None else output
  939. return TokenClassifierOutput(
  940. loss=loss,
  941. logits=logits,
  942. hidden_states=outputs.hidden_states,
  943. attentions=outputs.attentions,
  944. )
  945. @add_start_docstrings(
  946. """
  947. LayoutLMv3 Model with a span classification head on top for extractive question-answering tasks such as
  948. [DocVQA](https://rrc.cvc.uab.es/?ch=17) (a linear layer on top of the text part of the hidden-states output to
  949. compute `span start logits` and `span end logits`).
  950. """,
  951. LAYOUTLMV3_START_DOCSTRING,
  952. )
  953. class LayoutLMv3ForQuestionAnswering(LayoutLMv3PreTrainedModel):
  954. def __init__(self, config):
  955. super().__init__(config)
  956. self.num_labels = config.num_labels
  957. self.layoutlmv3 = LayoutLMv3Model(config)
  958. self.qa_outputs = LayoutLMv3ClassificationHead(config, pool_feature=False)
  959. self.init_weights()
  960. @add_start_docstrings_to_model_forward(
  961. LAYOUTLMV3_DOWNSTREAM_INPUTS_DOCSTRING.format("batch_size, sequence_length")
  962. )
  963. @replace_return_docstrings(output_type=QuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC)
  964. def forward(
  965. self,
  966. input_ids: Optional[torch.LongTensor] = None,
  967. attention_mask: Optional[torch.FloatTensor] = None,
  968. token_type_ids: Optional[torch.LongTensor] = None,
  969. position_ids: Optional[torch.LongTensor] = None,
  970. head_mask: Optional[torch.FloatTensor] = None,
  971. inputs_embeds: Optional[torch.FloatTensor] = None,
  972. start_positions: Optional[torch.LongTensor] = None,
  973. end_positions: Optional[torch.LongTensor] = None,
  974. output_attentions: Optional[bool] = None,
  975. output_hidden_states: Optional[bool] = None,
  976. return_dict: Optional[bool] = None,
  977. bbox: Optional[torch.LongTensor] = None,
  978. pixel_values: Optional[torch.LongTensor] = None,
  979. ) -> Union[Tuple, QuestionAnsweringModelOutput]:
  980. r"""
  981. start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  982. Labels for position (index) of the start of the labelled span for computing the token classification loss.
  983. Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
  984. are not taken into account for computing the loss.
  985. end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  986. Labels for position (index) of the end of the labelled span for computing the token classification loss.
  987. Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
  988. are not taken into account for computing the loss.
  989. Returns:
  990. Examples:
  991. ```python
  992. >>> from transformers import AutoProcessor, AutoModelForQuestionAnswering
  993. >>> from datasets import load_dataset
  994. >>> import torch
  995. >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
  996. >>> model = AutoModelForQuestionAnswering.from_pretrained("microsoft/layoutlmv3-base")
  997. >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train", trust_remote_code=True)
  998. >>> example = dataset[0]
  999. >>> image = example["image"]
  1000. >>> question = "what's his name?"
  1001. >>> words = example["tokens"]
  1002. >>> boxes = example["bboxes"]
  1003. >>> encoding = processor(image, question, words, boxes=boxes, return_tensors="pt")
  1004. >>> start_positions = torch.tensor([1])
  1005. >>> end_positions = torch.tensor([3])
  1006. >>> outputs = model(**encoding, start_positions=start_positions, end_positions=end_positions)
  1007. >>> loss = outputs.loss
  1008. >>> start_scores = outputs.start_logits
  1009. >>> end_scores = outputs.end_logits
  1010. ```"""
  1011. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1012. outputs = self.layoutlmv3(
  1013. input_ids,
  1014. attention_mask=attention_mask,
  1015. token_type_ids=token_type_ids,
  1016. position_ids=position_ids,
  1017. head_mask=head_mask,
  1018. inputs_embeds=inputs_embeds,
  1019. output_attentions=output_attentions,
  1020. output_hidden_states=output_hidden_states,
  1021. return_dict=return_dict,
  1022. bbox=bbox,
  1023. pixel_values=pixel_values,
  1024. )
  1025. sequence_output = outputs[0]
  1026. logits = self.qa_outputs(sequence_output)
  1027. start_logits, end_logits = logits.split(1, dim=-1)
  1028. start_logits = start_logits.squeeze(-1).contiguous()
  1029. end_logits = end_logits.squeeze(-1).contiguous()
  1030. total_loss = None
  1031. if start_positions is not None and end_positions is not None:
  1032. # If we are on multi-GPU, split add a dimension
  1033. if len(start_positions.size()) > 1:
  1034. start_positions = start_positions.squeeze(-1)
  1035. if len(end_positions.size()) > 1:
  1036. end_positions = end_positions.squeeze(-1)
  1037. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  1038. ignored_index = start_logits.size(1)
  1039. start_positions = start_positions.clamp(0, ignored_index)
  1040. end_positions = end_positions.clamp(0, ignored_index)
  1041. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  1042. start_loss = loss_fct(start_logits, start_positions)
  1043. end_loss = loss_fct(end_logits, end_positions)
  1044. total_loss = (start_loss + end_loss) / 2
  1045. if not return_dict:
  1046. output = (start_logits, end_logits) + outputs[1:]
  1047. return ((total_loss,) + output) if total_loss is not None else output
  1048. return QuestionAnsweringModelOutput(
  1049. loss=total_loss,
  1050. start_logits=start_logits,
  1051. end_logits=end_logits,
  1052. hidden_states=outputs.hidden_states,
  1053. attentions=outputs.attentions,
  1054. )
  1055. @add_start_docstrings(
  1056. """
  1057. LayoutLMv3 Model with a sequence classification head on top (a linear layer on top of the final hidden state of the
  1058. [CLS] token) e.g. for document image classification tasks such as the
  1059. [RVL-CDIP](https://www.cs.cmu.edu/~aharley/rvl-cdip/) dataset.
  1060. """,
  1061. LAYOUTLMV3_START_DOCSTRING,
  1062. )
  1063. class LayoutLMv3ForSequenceClassification(LayoutLMv3PreTrainedModel):
  1064. def __init__(self, config):
  1065. super().__init__(config)
  1066. self.num_labels = config.num_labels
  1067. self.config = config
  1068. self.layoutlmv3 = LayoutLMv3Model(config)
  1069. self.classifier = LayoutLMv3ClassificationHead(config, pool_feature=False)
  1070. self.init_weights()
  1071. @add_start_docstrings_to_model_forward(
  1072. LAYOUTLMV3_DOWNSTREAM_INPUTS_DOCSTRING.format("batch_size, sequence_length")
  1073. )
  1074. @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
  1075. def forward(
  1076. self,
  1077. input_ids: Optional[torch.LongTensor] = None,
  1078. attention_mask: Optional[torch.FloatTensor] = None,
  1079. token_type_ids: Optional[torch.LongTensor] = None,
  1080. position_ids: Optional[torch.LongTensor] = None,
  1081. head_mask: Optional[torch.FloatTensor] = None,
  1082. inputs_embeds: Optional[torch.FloatTensor] = None,
  1083. labels: Optional[torch.LongTensor] = None,
  1084. output_attentions: Optional[bool] = None,
  1085. output_hidden_states: Optional[bool] = None,
  1086. return_dict: Optional[bool] = None,
  1087. bbox: Optional[torch.LongTensor] = None,
  1088. pixel_values: Optional[torch.LongTensor] = None,
  1089. ) -> Union[Tuple, SequenceClassifierOutput]:
  1090. """
  1091. Returns:
  1092. Examples:
  1093. ```python
  1094. >>> from transformers import AutoProcessor, AutoModelForSequenceClassification
  1095. >>> from datasets import load_dataset
  1096. >>> import torch
  1097. >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
  1098. >>> model = AutoModelForSequenceClassification.from_pretrained("microsoft/layoutlmv3-base")
  1099. >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train", trust_remote_code=True)
  1100. >>> example = dataset[0]
  1101. >>> image = example["image"]
  1102. >>> words = example["tokens"]
  1103. >>> boxes = example["bboxes"]
  1104. >>> encoding = processor(image, words, boxes=boxes, return_tensors="pt")
  1105. >>> sequence_label = torch.tensor([1])
  1106. >>> outputs = model(**encoding, labels=sequence_label)
  1107. >>> loss = outputs.loss
  1108. >>> logits = outputs.logits
  1109. ```"""
  1110. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1111. outputs = self.layoutlmv3(
  1112. input_ids,
  1113. attention_mask=attention_mask,
  1114. token_type_ids=token_type_ids,
  1115. position_ids=position_ids,
  1116. head_mask=head_mask,
  1117. inputs_embeds=inputs_embeds,
  1118. output_attentions=output_attentions,
  1119. output_hidden_states=output_hidden_states,
  1120. return_dict=return_dict,
  1121. bbox=bbox,
  1122. pixel_values=pixel_values,
  1123. )
  1124. sequence_output = outputs[0][:, 0, :]
  1125. logits = self.classifier(sequence_output)
  1126. loss = None
  1127. if labels is not None:
  1128. if self.config.problem_type is None:
  1129. if self.num_labels == 1:
  1130. self.config.problem_type = "regression"
  1131. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  1132. self.config.problem_type = "single_label_classification"
  1133. else:
  1134. self.config.problem_type = "multi_label_classification"
  1135. if self.config.problem_type == "regression":
  1136. loss_fct = MSELoss()
  1137. if self.num_labels == 1:
  1138. loss = loss_fct(logits.squeeze(), labels.squeeze())
  1139. else:
  1140. loss = loss_fct(logits, labels)
  1141. elif self.config.problem_type == "single_label_classification":
  1142. loss_fct = CrossEntropyLoss()
  1143. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1144. elif self.config.problem_type == "multi_label_classification":
  1145. loss_fct = BCEWithLogitsLoss()
  1146. loss = loss_fct(logits, labels)
  1147. if not return_dict:
  1148. output = (logits,) + outputs[1:]
  1149. return ((loss,) + output) if loss is not None else output
  1150. return SequenceClassifierOutput(
  1151. loss=loss,
  1152. logits=logits,
  1153. hidden_states=outputs.hidden_states,
  1154. attentions=outputs.attentions,
  1155. )