# coding=utf-8
# Copyright 2018 The Microsoft Research Asia LayoutLM Team Authors and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch LayoutLM model."""

import math
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    MaskedLMOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_layoutlm import LayoutLMConfig


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "LayoutLMConfig"
_CHECKPOINT_FOR_DOC = "microsoft/layoutlm-base-uncased"

LayoutLMLayerNorm = nn.LayerNorm

class LayoutLMEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        super(LayoutLMEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.x_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size)
        self.y_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size)
        self.h_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size)
        self.w_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        self.LayerNorm = LayoutLMLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(
        self,
        input_ids=None,
        bbox=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
    ):
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        words_embeddings = inputs_embeds
        position_embeddings = self.position_embeddings(position_ids)

        try:
            left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])
            upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1])
            right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2])
            lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3])
        except IndexError as e:
            raise IndexError("The `bbox` coordinate values should be within 0-1000 range.") from e
        h_position_embeddings = self.h_position_embeddings(bbox[:, :, 3] - bbox[:, :, 1])
        w_position_embeddings = self.w_position_embeddings(bbox[:, :, 2] - bbox[:, :, 0])

        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = (
            words_embeddings
            + position_embeddings
            + left_position_embeddings
            + upper_position_embeddings
            + right_position_embeddings
            + lower_position_embeddings
            + h_position_embeddings
            + w_position_embeddings
            + token_type_embeddings
        )
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
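

# Example-only sketch, not part of the original modeling code: the `bbox` tensor consumed by
# `LayoutLMEmbeddings.forward` is expected to contain integer coordinates already normalized to the
# 0-1000 range used by the 2D position embedding tables above. Assuming pixel-space (x0, y0, x1, y1)
# boxes and a known page size, a typical normalization looks like this; values outside 0-1000 would
# index past the embedding tables and trigger the IndexError raised above.
def _example_normalize_box(box, page_width, page_height):
    x0, y0, x1, y1 = box
    return [
        int(1000 * x0 / page_width),
        int(1000 * y0 / page_height),
        int(1000 * x1 / page_width),
        int(1000 * y1 / page_height),
    ]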

# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->LayoutLM
class LayoutLMSelfAttention(nn.Module):
    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        use_cache = past_key_value is not None
        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            if use_cache:
                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
                    -1, 1
                )
            else:
                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r

            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in LayoutLMModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs
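

# Shape walk-through for the attention above (a reading aid, not part of the original file): with batch size B,
# sequence length L, hidden size H, `num_attention_heads` n and head size d = H // n, `transpose_for_scores`
# reshapes (B, L, H) -> (B, L, n, d) -> (B, n, L, d); the query/key matmul then produces attention scores of
# shape (B, n, L, L), and the final permute + view folds the weighted values back into a (B, L, H) context.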

# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->LayoutLM
class LayoutLMSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


LAYOUTLM_SELF_ATTENTION_CLASSES = {
    "eager": LayoutLMSelfAttention,
}

# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->LayoutLM,BERT->LAYOUTLM
class LayoutLMAttention(nn.Module):
    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        self.self = LAYOUTLM_SELF_ATTENTION_CLASSES[config._attn_implementation](
            config, position_embedding_type=position_embedding_type
        )
        self.output = LayoutLMSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs

# Copied from transformers.models.bert.modeling_bert.BertIntermediate
class LayoutLMIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->LayoutLM
class LayoutLMOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->LayoutLM
class LayoutLMLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = LayoutLMAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            self.crossattention = LayoutLMAttention(config, position_embedding_type="absolute")
        self.intermediate = LayoutLMIntermediate(config)
        self.output = LayoutLMOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]

        # if decoder, the last output is tuple of self-attn cache
        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        cross_attn_present_key_value = None
        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )

            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights

            # add cross-attn cache to positions 3,4 of present_key_value tuple
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs

        # if decoder, return the attn key/values as the last output
        if self.is_decoder:
            outputs = outputs + (present_key_value,)

        return outputs
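    # Note (reading aid, not from the original file): `apply_chunking_to_forward` above runs `feed_forward_chunk`
    # on slices of `attention_output` of size `chunk_size_feed_forward` along the sequence dimension
    # (`seq_len_dim = 1`) and concatenates the results, trading speed for lower peak memory; with a chunk size
    # of 0 it reduces to a single plain call.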
    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->LayoutLM
class LayoutLMEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([LayoutLMLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )

# Copied from transformers.models.bert.modeling_bert.BertPooler
class LayoutLMPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

# Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->LayoutLM
class LayoutLMPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->LayoutLM
class LayoutLMLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transform = LayoutLMPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def _tie_weights(self):
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states


# Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->LayoutLM
class LayoutLMOnlyMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = LayoutLMLMPredictionHead(config)

    def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores

class LayoutLMPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = LayoutLMConfig
    base_model_prefix = "layoutlm"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, LayoutLMLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

LAYOUTLM_START_DOCSTRING = r"""
    The LayoutLM model was proposed in [LayoutLM: Pre-training of Text and Layout for Document Image
    Understanding](https://arxiv.org/abs/1912.13318) by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei and
    Ming Zhou.

    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`LayoutLMConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

LAYOUTLM_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        bbox (`torch.LongTensor` of shape `({0}, 4)`, *optional*):
            Bounding boxes of each input sequence tokens. Selected in the range `[0,
            config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
            format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
            y1) represents the position of the lower right corner. See [Overview](#Overview) for normalization.
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: `1` for
            tokens that are NOT MASKED, `0` for MASKED tokens.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`: `0` corresponds to a *sentence A* token, `1` corresponds to a *sentence B* token

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: `1`
            indicates the head is **not masked**, `0` indicates the head is **masked**.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            If set to `True`, the attentions tensors of all attention layers are returned. See `attentions` under
            returned tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            If set to `True`, the hidden states of all layers are returned. See `hidden_states` under returned tensors
            for more detail.
        return_dict (`bool`, *optional*):
            If set to `True`, the model will return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

@add_start_docstrings(
    "The bare LayoutLM Model transformer outputting raw hidden-states without any specific head on top.",
    LAYOUTLM_START_DOCSTRING,
)
class LayoutLMModel(LayoutLMPreTrainedModel):
    def __init__(self, config):
        super(LayoutLMModel, self).__init__(config)
        self.config = config

        self.embeddings = LayoutLMEmbeddings(config)
        self.encoder = LayoutLMEncoder(config)
        self.pooler = LayoutLMPooler(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=BaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        bbox: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, LayoutLMModel
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
        >>> model = LayoutLMModel.from_pretrained("microsoft/layoutlm-base-uncased")

        >>> words = ["Hello", "world"]
        >>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782]

        >>> token_boxes = []
        >>> for word, box in zip(words, normalized_word_boxes):
        ...     word_tokens = tokenizer.tokenize(word)
        ...     token_boxes.extend([box] * len(word_tokens))
        >>> # add bounding boxes of cls + sep tokens
        >>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]

        >>> encoding = tokenizer(" ".join(words), return_tensors="pt")
        >>> input_ids = encoding["input_ids"]
        >>> attention_mask = encoding["attention_mask"]
        >>> token_type_ids = encoding["token_type_ids"]
        >>> bbox = torch.tensor([token_boxes])

        >>> outputs = model(
        ...     input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids
        ... )

        >>> last_hidden_states = outputs.last_hidden_state
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        if bbox is None:
            bbox = torch.zeros(input_shape + (4,), dtype=torch.long, device=device)
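
        # Note (reading aid, not from the original file): the 2D padding mask (1 = attend, 0 = ignore) is turned
        # below into an additive bias of shape (batch_size, 1, 1, seq_length): kept positions become 0.0 and
        # masked positions become the most negative finite value of the model dtype, so they vanish after the
        # softmax inside self-attention.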
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min

        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)
        else:
            head_mask = [None] * self.config.num_hidden_layers

        embedding_output = self.embeddings(
            input_ids=input_ids,
            bbox=bbox,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            extended_attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )

@add_start_docstrings("""LayoutLM Model with a `language modeling` head on top.""", LAYOUTLM_START_DOCSTRING)
class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
    _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]

    def __init__(self, config):
        super().__init__(config)

        self.layoutlm = LayoutLMModel(config)
        self.cls = LayoutLMOnlyMLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.layoutlm.embeddings.word_embeddings

    def get_output_embeddings(self):
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        self.cls.predictions.decoder = new_embeddings
        self.cls.predictions.bias = new_embeddings.bias

    @add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        bbox: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, LayoutLMForMaskedLM
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
        >>> model = LayoutLMForMaskedLM.from_pretrained("microsoft/layoutlm-base-uncased")

        >>> words = ["Hello", "[MASK]"]
        >>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782]

        >>> token_boxes = []
        >>> for word, box in zip(words, normalized_word_boxes):
        ...     word_tokens = tokenizer.tokenize(word)
        ...     token_boxes.extend([box] * len(word_tokens))
        >>> # add bounding boxes of cls + sep tokens
        >>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]

        >>> encoding = tokenizer(" ".join(words), return_tensors="pt")
        >>> input_ids = encoding["input_ids"]
        >>> attention_mask = encoding["attention_mask"]
        >>> token_type_ids = encoding["token_type_ids"]
        >>> bbox = torch.tensor([token_boxes])

        >>> labels = tokenizer("Hello world", return_tensors="pt")["input_ids"]

        >>> outputs = model(
        ...     input_ids=input_ids,
        ...     bbox=bbox,
        ...     attention_mask=attention_mask,
        ...     token_type_ids=token_type_ids,
        ...     labels=labels,
        ... )

        >>> loss = outputs.loss
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.layoutlm(
            input_ids,
            bbox,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size),
                labels.view(-1),
            )

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
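

# Note (reading aid, not from the original file): `get_output_embeddings`/`set_output_embeddings` above expose
# `cls.predictions.decoder` so that the generic `PreTrainedModel` machinery can tie it to the input word
# embeddings and keep both in sync when the vocabulary is resized (see `_tied_weights_keys`).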

@add_start_docstrings(
    """
    LayoutLM Model with a sequence classification head on top (a linear layer on top of the pooled output) e.g. for
    document image classification tasks such as the [RVL-CDIP](https://www.cs.cmu.edu/~aharley/rvl-cdip/) dataset.
    """,
    LAYOUTLM_START_DOCSTRING,
)
class LayoutLMForSequenceClassification(LayoutLMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.layoutlm = LayoutLMModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.layoutlm.embeddings.word_embeddings

    @add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        bbox: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, LayoutLMForSequenceClassification
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
        >>> model = LayoutLMForSequenceClassification.from_pretrained("microsoft/layoutlm-base-uncased")

        >>> words = ["Hello", "world"]
        >>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782]

        >>> token_boxes = []
        >>> for word, box in zip(words, normalized_word_boxes):
        ...     word_tokens = tokenizer.tokenize(word)
        ...     token_boxes.extend([box] * len(word_tokens))
        >>> # add bounding boxes of cls + sep tokens
        >>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]

        >>> encoding = tokenizer(" ".join(words), return_tensors="pt")
        >>> input_ids = encoding["input_ids"]
        >>> attention_mask = encoding["attention_mask"]
        >>> token_type_ids = encoding["token_type_ids"]
        >>> bbox = torch.tensor([token_boxes])
        >>> sequence_label = torch.tensor([1])

        >>> outputs = model(
        ...     input_ids=input_ids,
        ...     bbox=bbox,
        ...     attention_mask=attention_mask,
        ...     token_type_ids=token_type_ids,
        ...     labels=sequence_label,
        ... )

        >>> loss = outputs.loss
        >>> logits = outputs.logits
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.layoutlm(
            input_ids=input_ids,
            bbox=bbox,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

@add_start_docstrings(
    """
    LayoutLM Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    sequence labeling (information extraction) tasks such as the [FUNSD](https://guillaumejaume.github.io/FUNSD/)
    dataset and the [SROIE](https://rrc.cvc.uab.es/?ch=13) dataset.
    """,
    LAYOUTLM_START_DOCSTRING,
)
class LayoutLMForTokenClassification(LayoutLMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.layoutlm = LayoutLMModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.layoutlm.embeddings.word_embeddings

    @add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        bbox: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, LayoutLMForTokenClassification
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
        >>> model = LayoutLMForTokenClassification.from_pretrained("microsoft/layoutlm-base-uncased")

        >>> words = ["Hello", "world"]
        >>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782]

        >>> token_boxes = []
        >>> for word, box in zip(words, normalized_word_boxes):
        ...     word_tokens = tokenizer.tokenize(word)
        ...     token_boxes.extend([box] * len(word_tokens))
        >>> # add bounding boxes of cls + sep tokens
        >>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]

        >>> encoding = tokenizer(" ".join(words), return_tensors="pt")
        >>> input_ids = encoding["input_ids"]
        >>> attention_mask = encoding["attention_mask"]
        >>> token_type_ids = encoding["token_type_ids"]
        >>> bbox = torch.tensor([token_boxes])
        >>> token_labels = torch.tensor([1, 1, 0, 0]).unsqueeze(0)  # batch size of 1

        >>> outputs = model(
        ...     input_ids=input_ids,
        ...     bbox=bbox,
        ...     attention_mask=attention_mask,
        ...     token_type_ids=token_type_ids,
        ...     labels=token_labels,
        ... )

        >>> loss = outputs.loss
        >>> logits = outputs.logits
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.layoutlm(
            input_ids=input_ids,
            bbox=bbox,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

@add_start_docstrings(
    """
    LayoutLM Model with a span classification head on top for extractive question-answering tasks such as
    [DocVQA](https://rrc.cvc.uab.es/?ch=17) (a linear layer on top of the final hidden-states output to compute `span
    start logits` and `span end logits`).
    """,
    LAYOUTLM_START_DOCSTRING,
)
class LayoutLMForQuestionAnswering(LayoutLMPreTrainedModel):
    def __init__(self, config, has_visual_segment_embedding=True):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.layoutlm = LayoutLMModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.layoutlm.embeddings.word_embeddings

    @replace_return_docstrings(output_type=QuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        bbox: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.

        Returns:

        Example:

        In the example below, we prepare a question + context pair for the LayoutLM model. It will give us a prediction
        of what it thinks the answer is (the span of the answer within the texts parsed from the image).

        ```python
        >>> from transformers import AutoTokenizer, LayoutLMForQuestionAnswering
        >>> from datasets import load_dataset
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("impira/layoutlm-document-qa", add_prefix_space=True)
        >>> model = LayoutLMForQuestionAnswering.from_pretrained("impira/layoutlm-document-qa", revision="1e3ebac")

        >>> dataset = load_dataset("nielsr/funsd", split="train", trust_remote_code=True)
        >>> example = dataset[0]
        >>> question = "what's his name?"
        >>> words = example["words"]
        >>> boxes = example["bboxes"]

        >>> encoding = tokenizer(
        ...     question.split(), words, is_split_into_words=True, return_token_type_ids=True, return_tensors="pt"
        ... )
        >>> bbox = []
        >>> for i, s, w in zip(encoding.input_ids[0], encoding.sequence_ids(0), encoding.word_ids(0)):
        ...     if s == 1:
        ...         bbox.append(boxes[w])
        ...     elif i == tokenizer.sep_token_id:
        ...         bbox.append([1000] * 4)
        ...     else:
        ...         bbox.append([0] * 4)
        >>> encoding["bbox"] = torch.tensor([bbox])

        >>> word_ids = encoding.word_ids(0)
        >>> outputs = model(**encoding)
        >>> loss = outputs.loss
        >>> start_scores = outputs.start_logits
        >>> end_scores = outputs.end_logits
        >>> start, end = word_ids[start_scores.argmax(-1)], word_ids[end_scores.argmax(-1)]
        >>> print(" ".join(words[start : end + 1]))
        M. Hamann P. Harper, P. Martinez
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.layoutlm(
            input_ids=input_ids,
            bbox=bbox,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split adds a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )