modeling_layoutlmv2.py 60 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417
  1. # coding=utf-8
  2. # Copyright 2021 Microsoft Research The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch LayoutLMv2 model."""
  16. import math
  17. from typing import Optional, Tuple, Union
  18. import torch
  19. import torch.utils.checkpoint
  20. from torch import nn
  21. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  22. from ...activations import ACT2FN
  23. from ...modeling_outputs import (
  24. BaseModelOutput,
  25. BaseModelOutputWithPooling,
  26. QuestionAnsweringModelOutput,
  27. SequenceClassifierOutput,
  28. TokenClassifierOutput,
  29. )
  30. from ...modeling_utils import PreTrainedModel
  31. from ...pytorch_utils import apply_chunking_to_forward
  32. from ...utils import (
  33. add_start_docstrings,
  34. add_start_docstrings_to_model_forward,
  35. is_detectron2_available,
  36. logging,
  37. replace_return_docstrings,
  38. requires_backends,
  39. )
  40. from .configuration_layoutlmv2 import LayoutLMv2Config
  41. # soft dependency
  42. if is_detectron2_available():
  43. import detectron2
  44. from detectron2.modeling import META_ARCH_REGISTRY
  45. logger = logging.get_logger(__name__)
  46. _CHECKPOINT_FOR_DOC = "microsoft/layoutlmv2-base-uncased"
  47. _CONFIG_FOR_DOC = "LayoutLMv2Config"
  48. class LayoutLMv2Embeddings(nn.Module):
  49. """Construct the embeddings from word, position and token_type embeddings."""
  50. def __init__(self, config):
  51. super(LayoutLMv2Embeddings, self).__init__()
  52. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  53. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  54. self.x_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.coordinate_size)
  55. self.y_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.coordinate_size)
  56. self.h_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.shape_size)
  57. self.w_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.shape_size)
  58. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  59. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  60. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  61. self.register_buffer(
  62. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  63. )
  64. def _calc_spatial_position_embeddings(self, bbox):
  65. try:
  66. left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])
  67. upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1])
  68. right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2])
  69. lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3])
  70. except IndexError as e:
  71. raise IndexError("The `bbox` coordinate values should be within 0-1000 range.") from e
  72. h_position_embeddings = self.h_position_embeddings(bbox[:, :, 3] - bbox[:, :, 1])
  73. w_position_embeddings = self.w_position_embeddings(bbox[:, :, 2] - bbox[:, :, 0])
  74. spatial_position_embeddings = torch.cat(
  75. [
  76. left_position_embeddings,
  77. upper_position_embeddings,
  78. right_position_embeddings,
  79. lower_position_embeddings,
  80. h_position_embeddings,
  81. w_position_embeddings,
  82. ],
  83. dim=-1,
  84. )
  85. return spatial_position_embeddings
  86. class LayoutLMv2SelfAttention(nn.Module):
  87. def __init__(self, config):
  88. super().__init__()
  89. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  90. raise ValueError(
  91. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  92. f"heads ({config.num_attention_heads})"
  93. )
  94. self.fast_qkv = config.fast_qkv
  95. self.num_attention_heads = config.num_attention_heads
  96. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  97. self.all_head_size = self.num_attention_heads * self.attention_head_size
  98. self.has_relative_attention_bias = config.has_relative_attention_bias
  99. self.has_spatial_attention_bias = config.has_spatial_attention_bias
  100. if config.fast_qkv:
  101. self.qkv_linear = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=False)
  102. self.q_bias = nn.Parameter(torch.zeros(1, 1, self.all_head_size))
  103. self.v_bias = nn.Parameter(torch.zeros(1, 1, self.all_head_size))
  104. else:
  105. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  106. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  107. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  108. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  109. def transpose_for_scores(self, x):
  110. new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
  111. x = x.view(*new_x_shape)
  112. return x.permute(0, 2, 1, 3)
  113. def compute_qkv(self, hidden_states):
  114. if self.fast_qkv:
  115. qkv = self.qkv_linear(hidden_states)
  116. q, k, v = torch.chunk(qkv, 3, dim=-1)
  117. if q.ndimension() == self.q_bias.ndimension():
  118. q = q + self.q_bias
  119. v = v + self.v_bias
  120. else:
  121. _sz = (1,) * (q.ndimension() - 1) + (-1,)
  122. q = q + self.q_bias.view(*_sz)
  123. v = v + self.v_bias.view(*_sz)
  124. else:
  125. q = self.query(hidden_states)
  126. k = self.key(hidden_states)
  127. v = self.value(hidden_states)
  128. return q, k, v
  129. def forward(
  130. self,
  131. hidden_states,
  132. attention_mask=None,
  133. head_mask=None,
  134. output_attentions=False,
  135. rel_pos=None,
  136. rel_2d_pos=None,
  137. ):
  138. q, k, v = self.compute_qkv(hidden_states)
  139. # (B, L, H*D) -> (B, H, L, D)
  140. query_layer = self.transpose_for_scores(q)
  141. key_layer = self.transpose_for_scores(k)
  142. value_layer = self.transpose_for_scores(v)
  143. query_layer = query_layer / math.sqrt(self.attention_head_size)
  144. # [BSZ, NAT, L, L]
  145. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  146. if self.has_relative_attention_bias:
  147. attention_scores += rel_pos
  148. if self.has_spatial_attention_bias:
  149. attention_scores += rel_2d_pos
  150. attention_scores = attention_scores.float().masked_fill_(
  151. attention_mask.to(torch.bool), torch.finfo(attention_scores.dtype).min
  152. )
  153. attention_probs = nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32).type_as(value_layer)
  154. # This is actually dropping out entire tokens to attend to, which might
  155. # seem a bit unusual, but is taken from the original Transformer paper.
  156. attention_probs = self.dropout(attention_probs)
  157. # Mask heads if we want to
  158. if head_mask is not None:
  159. attention_probs = attention_probs * head_mask
  160. context_layer = torch.matmul(attention_probs, value_layer)
  161. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  162. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  163. context_layer = context_layer.view(*new_context_layer_shape)
  164. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  165. return outputs
  166. class LayoutLMv2Attention(nn.Module):
  167. def __init__(self, config):
  168. super().__init__()
  169. self.self = LayoutLMv2SelfAttention(config)
  170. self.output = LayoutLMv2SelfOutput(config)
  171. def forward(
  172. self,
  173. hidden_states,
  174. attention_mask=None,
  175. head_mask=None,
  176. output_attentions=False,
  177. rel_pos=None,
  178. rel_2d_pos=None,
  179. ):
  180. self_outputs = self.self(
  181. hidden_states,
  182. attention_mask,
  183. head_mask,
  184. output_attentions,
  185. rel_pos=rel_pos,
  186. rel_2d_pos=rel_2d_pos,
  187. )
  188. attention_output = self.output(self_outputs[0], hidden_states)
  189. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  190. return outputs
  191. class LayoutLMv2SelfOutput(nn.Module):
  192. def __init__(self, config):
  193. super().__init__()
  194. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  195. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  196. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  197. def forward(self, hidden_states, input_tensor):
  198. hidden_states = self.dense(hidden_states)
  199. hidden_states = self.dropout(hidden_states)
  200. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  201. return hidden_states
  202. # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->LayoutLMv2
  203. class LayoutLMv2Intermediate(nn.Module):
  204. def __init__(self, config):
  205. super().__init__()
  206. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  207. if isinstance(config.hidden_act, str):
  208. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  209. else:
  210. self.intermediate_act_fn = config.hidden_act
  211. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  212. hidden_states = self.dense(hidden_states)
  213. hidden_states = self.intermediate_act_fn(hidden_states)
  214. return hidden_states
  215. # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->LayoutLM
  216. class LayoutLMv2Output(nn.Module):
  217. def __init__(self, config):
  218. super().__init__()
  219. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  220. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  221. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  222. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  223. hidden_states = self.dense(hidden_states)
  224. hidden_states = self.dropout(hidden_states)
  225. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  226. return hidden_states
  227. class LayoutLMv2Layer(nn.Module):
  228. def __init__(self, config):
  229. super().__init__()
  230. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  231. self.seq_len_dim = 1
  232. self.attention = LayoutLMv2Attention(config)
  233. self.intermediate = LayoutLMv2Intermediate(config)
  234. self.output = LayoutLMv2Output(config)
  235. def forward(
  236. self,
  237. hidden_states,
  238. attention_mask=None,
  239. head_mask=None,
  240. output_attentions=False,
  241. rel_pos=None,
  242. rel_2d_pos=None,
  243. ):
  244. self_attention_outputs = self.attention(
  245. hidden_states,
  246. attention_mask,
  247. head_mask,
  248. output_attentions=output_attentions,
  249. rel_pos=rel_pos,
  250. rel_2d_pos=rel_2d_pos,
  251. )
  252. attention_output = self_attention_outputs[0]
  253. outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
  254. layer_output = apply_chunking_to_forward(
  255. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
  256. )
  257. outputs = (layer_output,) + outputs
  258. return outputs
  259. def feed_forward_chunk(self, attention_output):
  260. intermediate_output = self.intermediate(attention_output)
  261. layer_output = self.output(intermediate_output, attention_output)
  262. return layer_output
  263. def relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
  264. """
  265. Adapted from Mesh Tensorflow:
  266. https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
  267. Translate relative position to a bucket number for relative attention. The relative position is defined as
  268. memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
  269. position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for small
  270. absolute relative_position and larger buckets for larger absolute relative_positions. All relative positions
  271. >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. This should
  272. allow for more graceful generalization to longer sequences than the model has been trained on.
  273. Args:
  274. relative_position: an int32 Tensor
  275. bidirectional: a boolean - whether the attention is bidirectional
  276. num_buckets: an integer
  277. max_distance: an integer
  278. Returns:
  279. a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
  280. """
  281. ret = 0
  282. if bidirectional:
  283. num_buckets //= 2
  284. ret += (relative_position > 0).long() * num_buckets
  285. n = torch.abs(relative_position)
  286. else:
  287. n = torch.max(-relative_position, torch.zeros_like(relative_position))
  288. # now n is in the range [0, inf)
  289. # half of the buckets are for exact increments in positions
  290. max_exact = num_buckets // 2
  291. is_small = n < max_exact
  292. # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
  293. val_if_large = max_exact + (
  294. torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
  295. ).to(torch.long)
  296. val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
  297. ret += torch.where(is_small, n, val_if_large)
  298. return ret
class LayoutLMv2Encoder(nn.Module):
    # Stack of LayoutLMv2Layer modules plus the shared relative-position bias
    # projections (1D sequence distance and 2D spatial distance) whose outputs are
    # computed once per forward pass and fed to every layer's attention.
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([LayoutLMv2Layer(config) for _ in range(config.num_hidden_layers)])

        self.has_relative_attention_bias = config.has_relative_attention_bias
        self.has_spatial_attention_bias = config.has_spatial_attention_bias

        if self.has_relative_attention_bias:
            self.rel_pos_bins = config.rel_pos_bins
            self.max_rel_pos = config.max_rel_pos
            # Bias-free Linear whose transposed weight is used below as a
            # (bins -> num_heads) lookup table for bucketed distances.
            self.rel_pos_bias = nn.Linear(self.rel_pos_bins, config.num_attention_heads, bias=False)

        if self.has_spatial_attention_bias:
            self.max_rel_2d_pos = config.max_rel_2d_pos
            self.rel_2d_pos_bins = config.rel_2d_pos_bins
            self.rel_pos_x_bias = nn.Linear(self.rel_2d_pos_bins, config.num_attention_heads, bias=False)
            self.rel_pos_y_bias = nn.Linear(self.rel_2d_pos_bins, config.num_attention_heads, bias=False)

        self.gradient_checkpointing = False

    def _calculate_1d_position_embeddings(self, position_ids):
        # Pairwise token distances, bucketed, then mapped to a per-head bias of
        # shape (batch, num_heads, seq_len, seq_len) via the weight-table lookup.
        rel_pos_mat = position_ids.unsqueeze(-2) - position_ids.unsqueeze(-1)
        rel_pos = relative_position_bucket(
            rel_pos_mat,
            num_buckets=self.rel_pos_bins,
            max_distance=self.max_rel_pos,
        )
        # Since this is a simple indexing operation that is independent of the input,
        # no need to track gradients for this operation
        #
        # Without this no_grad context, training speed slows down significantly
        with torch.no_grad():
            rel_pos = self.rel_pos_bias.weight.t()[rel_pos].permute(0, 3, 1, 2)
            rel_pos = rel_pos.contiguous()
        return rel_pos

    def _calculate_2d_position_embeddings(self, bbox):
        # Spatial bias built from the boxes' x0 (left) and y1 (lower) coordinates;
        # horizontal and vertical distances are bucketed independently and summed.
        position_coord_x = bbox[:, :, 0]
        position_coord_y = bbox[:, :, 3]
        rel_pos_x_2d_mat = position_coord_x.unsqueeze(-2) - position_coord_x.unsqueeze(-1)
        rel_pos_y_2d_mat = position_coord_y.unsqueeze(-2) - position_coord_y.unsqueeze(-1)
        rel_pos_x = relative_position_bucket(
            rel_pos_x_2d_mat,
            num_buckets=self.rel_2d_pos_bins,
            max_distance=self.max_rel_2d_pos,
        )
        rel_pos_y = relative_position_bucket(
            rel_pos_y_2d_mat,
            num_buckets=self.rel_2d_pos_bins,
            max_distance=self.max_rel_2d_pos,
        )
        # Since this is a simple indexing operation that is independent of the input,
        # no need to track gradients for this operation
        #
        # Without this no_grad context, training speed slows down significantly
        with torch.no_grad():
            rel_pos_x = self.rel_pos_x_bias.weight.t()[rel_pos_x].permute(0, 3, 1, 2)
            rel_pos_y = self.rel_pos_y_bias.weight.t()[rel_pos_y].permute(0, 3, 1, 2)
            rel_pos_x = rel_pos_x.contiguous()
            rel_pos_y = rel_pos_y.contiguous()
        rel_2d_pos = rel_pos_x + rel_pos_y
        return rel_2d_pos

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        bbox=None,
        position_ids=None,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        # Position biases are computed once here and shared by every layer below.
        rel_pos = self._calculate_1d_position_embeddings(position_ids) if self.has_relative_attention_bias else None
        rel_2d_pos = self._calculate_2d_position_embeddings(bbox) if self.has_spatial_attention_bias else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:
                # Trade compute for memory: re-run the layer in the backward pass.
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    output_attentions,
                    rel_pos=rel_pos,
                    rel_2d_pos=rel_2d_pos,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    output_attentions,
                    rel_pos=rel_pos,
                    rel_2d_pos=rel_2d_pos,
                )

            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        # Append the final hidden state after the last layer.
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            # Tuple output: drop entries the caller did not request (they are None).
            return tuple(
                v
                for v in [
                    hidden_states,
                    all_hidden_states,
                    all_self_attentions,
                ]
                if v is not None
            )
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
  415. class LayoutLMv2PreTrainedModel(PreTrainedModel):
  416. """
  417. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  418. models.
  419. """
  420. config_class = LayoutLMv2Config
  421. base_model_prefix = "layoutlmv2"
  422. def _init_weights(self, module):
  423. """Initialize the weights"""
  424. if isinstance(module, nn.Linear):
  425. # Slightly different from the TF version which uses truncated_normal for initialization
  426. # cf https://github.com/pytorch/pytorch/pull/5617
  427. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  428. if module.bias is not None:
  429. module.bias.data.zero_()
  430. elif isinstance(module, nn.Embedding):
  431. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  432. if module.padding_idx is not None:
  433. module.weight.data[module.padding_idx].zero_()
  434. elif isinstance(module, nn.LayerNorm):
  435. module.bias.data.zero_()
  436. module.weight.data.fill_(1.0)
  437. elif isinstance(module, LayoutLMv2Model):
  438. if hasattr(module, "visual_segment_embedding"):
  439. module.visual_segment_embedding.data.normal_(mean=0.0, std=self.config.initializer_range)
  440. def my_convert_sync_batchnorm(module, process_group=None):
  441. # same as `nn.modules.SyncBatchNorm.convert_sync_batchnorm` but allowing converting from `detectron2.layers.FrozenBatchNorm2d`
  442. if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
  443. return nn.modules.SyncBatchNorm.convert_sync_batchnorm(module, process_group)
  444. module_output = module
  445. if isinstance(module, detectron2.layers.FrozenBatchNorm2d):
  446. module_output = torch.nn.SyncBatchNorm(
  447. num_features=module.num_features,
  448. eps=module.eps,
  449. affine=True,
  450. track_running_stats=True,
  451. process_group=process_group,
  452. )
  453. module_output.weight = torch.nn.Parameter(module.weight)
  454. module_output.bias = torch.nn.Parameter(module.bias)
  455. module_output.running_mean = module.running_mean
  456. module_output.running_var = module.running_var
  457. module_output.num_batches_tracked = torch.tensor(0, dtype=torch.long, device=module.running_mean.device)
  458. for name, child in module.named_children():
  459. module_output.add_module(name, my_convert_sync_batchnorm(child, process_group))
  460. del module
  461. return module_output
class LayoutLMv2VisualBackbone(nn.Module):
    # Wraps a detectron2 FPN backbone to turn each document image into a flattened
    # grid of pooled "p2" feature vectors.
    def __init__(self, config):
        super().__init__()
        self.cfg = config.get_detectron2_config()
        meta_arch = self.cfg.MODEL.META_ARCHITECTURE
        # Instantiate the full detectron2 model only to extract its FPN backbone.
        model = META_ARCH_REGISTRY.get(meta_arch)(self.cfg)
        assert isinstance(model.backbone, detectron2.modeling.backbone.FPN)
        self.backbone = model.backbone

        assert len(self.cfg.MODEL.PIXEL_MEAN) == len(self.cfg.MODEL.PIXEL_STD)
        num_channels = len(self.cfg.MODEL.PIXEL_MEAN)
        # Per-channel normalization stats; non-persistent so they stay out of checkpoints.
        self.register_buffer(
            "pixel_mean",
            torch.Tensor(self.cfg.MODEL.PIXEL_MEAN).view(num_channels, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "pixel_std", torch.Tensor(self.cfg.MODEL.PIXEL_STD).view(num_channels, 1, 1), persistent=False
        )
        self.out_feature_key = "p2"
        if torch.are_deterministic_algorithms_enabled():
            # AdaptiveAvgPool2d has no deterministic backward; fall back to a fixed
            # AvgPool2d sized for 224x224 inputs.
            logger.warning("using `AvgPool2d` instead of `AdaptiveAvgPool2d`")
            input_shape = (224, 224)
            backbone_stride = self.backbone.output_shape()[self.out_feature_key].stride
            self.pool = nn.AvgPool2d(
                (
                    math.ceil(math.ceil(input_shape[0] / backbone_stride) / config.image_feature_pool_shape[0]),
                    math.ceil(math.ceil(input_shape[1] / backbone_stride) / config.image_feature_pool_shape[1]),
                )
            )
        else:
            self.pool = nn.AdaptiveAvgPool2d(config.image_feature_pool_shape[:2])
        # If the config omitted the channel count, fill it in from the backbone
        # (NOTE: mutates the shared config object in place).
        if len(config.image_feature_pool_shape) == 2:
            config.image_feature_pool_shape.append(self.backbone.output_shape()[self.out_feature_key].channels)
        assert self.backbone.output_shape()[self.out_feature_key].channels == config.image_feature_pool_shape[2]

    def forward(self, images):
        # `images` may be a plain tensor or a detectron2 ImageList (use its `.tensor`).
        images_input = ((images if torch.is_tensor(images) else images.tensor) - self.pixel_mean) / self.pixel_std
        features = self.backbone(images_input)
        features = features[self.out_feature_key]
        # (B, C, H, W) -> pooled -> (B, H*W, C)
        features = self.pool(features).flatten(start_dim=2).transpose(1, 2).contiguous()
        return features

    def synchronize_batch_norm(self):
        # Convert the backbone's batch norms to SyncBatchNorm, grouped per node
        # (assumes ranks are assigned node-contiguously).
        if not (
            torch.distributed.is_available()
            and torch.distributed.is_initialized()
            and torch.distributed.get_rank() > -1
        ):
            raise RuntimeError("Make sure torch.distributed is set up properly.")

        self_rank = torch.distributed.get_rank()
        node_size = torch.cuda.device_count()
        world_size = torch.distributed.get_world_size()
        if not (world_size % node_size == 0):
            raise RuntimeError("Make sure the number of processes can be divided by the number of nodes")

        # One process group per node; each rank syncs only within its own node.
        node_global_ranks = [list(range(i * node_size, (i + 1) * node_size)) for i in range(world_size // node_size)]
        sync_bn_groups = [
            torch.distributed.new_group(ranks=node_global_ranks[i]) for i in range(world_size // node_size)
        ]
        node_rank = self_rank // node_size

        self.backbone = my_convert_sync_batchnorm(self.backbone, process_group=sync_bn_groups[node_rank])
# Docstring fragments consumed by the `add_start_docstrings*` decorators on the
# model classes below.
LAYOUTLMV2_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`LayoutLMv2Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# `{0}` is filled in with the concrete input-shape string by the forward docstring
# decorator.
LAYOUTLMV2_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `{0}`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        bbox (`torch.LongTensor` of shape `({0}, 4)`, *optional*):
            Bounding boxes of each input sequence tokens. Selected in the range `[0,
            config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
            format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
            y1) represents the position of the lower right corner.
        image (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `detectron.structures.ImageList` whose `tensors` is of shape `(batch_size, num_channels, height, width)`):
            Batch of document images.
        attention_mask (`torch.FloatTensor` of shape `{0}`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `{0}`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `{0}`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
  575. class LayoutLMv2Pooler(nn.Module):
  576. def __init__(self, config):
  577. super().__init__()
  578. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  579. self.activation = nn.Tanh()
  580. def forward(self, hidden_states):
  581. # We "pool" the model by simply taking the hidden state corresponding
  582. # to the first token.
  583. first_token_tensor = hidden_states[:, 0]
  584. pooled_output = self.dense(first_token_tensor)
  585. pooled_output = self.activation(pooled_output)
  586. return pooled_output
  587. @add_start_docstrings(
  588. "The bare LayoutLMv2 Model transformer outputting raw hidden-states without any specific head on top.",
  589. LAYOUTLMV2_START_DOCSTRING,
  590. )
  591. class LayoutLMv2Model(LayoutLMv2PreTrainedModel):
  592. def __init__(self, config):
  593. requires_backends(self, "detectron2")
  594. super().__init__(config)
  595. self.config = config
  596. self.has_visual_segment_embedding = config.has_visual_segment_embedding
  597. self.embeddings = LayoutLMv2Embeddings(config)
  598. self.visual = LayoutLMv2VisualBackbone(config)
  599. self.visual_proj = nn.Linear(config.image_feature_pool_shape[-1], config.hidden_size)
  600. if self.has_visual_segment_embedding:
  601. self.visual_segment_embedding = nn.Parameter(nn.Embedding(1, config.hidden_size).weight[0])
  602. self.visual_LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  603. self.visual_dropout = nn.Dropout(config.hidden_dropout_prob)
  604. self.encoder = LayoutLMv2Encoder(config)
  605. self.pooler = LayoutLMv2Pooler(config)
  606. # Initialize weights and apply final processing
  607. self.post_init()
  608. def get_input_embeddings(self):
  609. return self.embeddings.word_embeddings
  610. def set_input_embeddings(self, value):
  611. self.embeddings.word_embeddings = value
  612. def _calc_text_embeddings(self, input_ids, bbox, position_ids, token_type_ids, inputs_embeds=None):
  613. if input_ids is not None:
  614. input_shape = input_ids.size()
  615. else:
  616. input_shape = inputs_embeds.size()[:-1]
  617. seq_length = input_shape[1]
  618. if position_ids is None:
  619. position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
  620. position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
  621. if token_type_ids is None:
  622. token_type_ids = torch.zeros_like(input_ids)
  623. if inputs_embeds is None:
  624. inputs_embeds = self.embeddings.word_embeddings(input_ids)
  625. position_embeddings = self.embeddings.position_embeddings(position_ids)
  626. spatial_position_embeddings = self.embeddings._calc_spatial_position_embeddings(bbox)
  627. token_type_embeddings = self.embeddings.token_type_embeddings(token_type_ids)
  628. embeddings = inputs_embeds + position_embeddings + spatial_position_embeddings + token_type_embeddings
  629. embeddings = self.embeddings.LayerNorm(embeddings)
  630. embeddings = self.embeddings.dropout(embeddings)
  631. return embeddings
  632. def _calc_img_embeddings(self, image, bbox, position_ids):
  633. visual_embeddings = self.visual_proj(self.visual(image))
  634. position_embeddings = self.embeddings.position_embeddings(position_ids)
  635. spatial_position_embeddings = self.embeddings._calc_spatial_position_embeddings(bbox)
  636. embeddings = visual_embeddings + position_embeddings + spatial_position_embeddings
  637. if self.has_visual_segment_embedding:
  638. embeddings += self.visual_segment_embedding
  639. embeddings = self.visual_LayerNorm(embeddings)
  640. embeddings = self.visual_dropout(embeddings)
  641. return embeddings
  642. def _calc_visual_bbox(self, image_feature_pool_shape, bbox, device, final_shape):
  643. visual_bbox_x = torch.div(
  644. torch.arange(
  645. 0,
  646. 1000 * (image_feature_pool_shape[1] + 1),
  647. 1000,
  648. device=device,
  649. dtype=bbox.dtype,
  650. ),
  651. self.config.image_feature_pool_shape[1],
  652. rounding_mode="floor",
  653. )
  654. visual_bbox_y = torch.div(
  655. torch.arange(
  656. 0,
  657. 1000 * (self.config.image_feature_pool_shape[0] + 1),
  658. 1000,
  659. device=device,
  660. dtype=bbox.dtype,
  661. ),
  662. self.config.image_feature_pool_shape[0],
  663. rounding_mode="floor",
  664. )
  665. visual_bbox = torch.stack(
  666. [
  667. visual_bbox_x[:-1].repeat(image_feature_pool_shape[0], 1),
  668. visual_bbox_y[:-1].repeat(image_feature_pool_shape[1], 1).transpose(0, 1),
  669. visual_bbox_x[1:].repeat(image_feature_pool_shape[0], 1),
  670. visual_bbox_y[1:].repeat(image_feature_pool_shape[1], 1).transpose(0, 1),
  671. ],
  672. dim=-1,
  673. ).view(-1, bbox.size(-1))
  674. visual_bbox = visual_bbox.repeat(final_shape[0], 1, 1)
  675. return visual_bbox
  676. def _get_input_shape(self, input_ids=None, inputs_embeds=None):
  677. if input_ids is not None and inputs_embeds is not None:
  678. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  679. elif input_ids is not None:
  680. return input_ids.size()
  681. elif inputs_embeds is not None:
  682. return inputs_embeds.size()[:-1]
  683. else:
  684. raise ValueError("You have to specify either input_ids or inputs_embeds")
  685. @add_start_docstrings_to_model_forward(LAYOUTLMV2_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
  686. @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
  687. def forward(
  688. self,
  689. input_ids: Optional[torch.LongTensor] = None,
  690. bbox: Optional[torch.LongTensor] = None,
  691. image: Optional[torch.FloatTensor] = None,
  692. attention_mask: Optional[torch.FloatTensor] = None,
  693. token_type_ids: Optional[torch.LongTensor] = None,
  694. position_ids: Optional[torch.LongTensor] = None,
  695. head_mask: Optional[torch.FloatTensor] = None,
  696. inputs_embeds: Optional[torch.FloatTensor] = None,
  697. output_attentions: Optional[bool] = None,
  698. output_hidden_states: Optional[bool] = None,
  699. return_dict: Optional[bool] = None,
  700. ) -> Union[Tuple, BaseModelOutputWithPooling]:
  701. r"""
  702. Return:
  703. Examples:
  704. ```python
  705. >>> from transformers import AutoProcessor, LayoutLMv2Model, set_seed
  706. >>> from PIL import Image
  707. >>> import torch
  708. >>> from datasets import load_dataset
  709. >>> set_seed(0)
  710. >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv2-base-uncased")
  711. >>> model = LayoutLMv2Model.from_pretrained("microsoft/layoutlmv2-base-uncased")
  712. >>> dataset = load_dataset("hf-internal-testing/fixtures_docvqa", trust_remote_code=True)
  713. >>> image_path = dataset["test"][0]["file"]
  714. >>> image = Image.open(image_path).convert("RGB")
  715. >>> encoding = processor(image, return_tensors="pt")
  716. >>> outputs = model(**encoding)
  717. >>> last_hidden_states = outputs.last_hidden_state
  718. >>> last_hidden_states.shape
  719. torch.Size([1, 342, 768])
  720. ```
  721. """
  722. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  723. output_hidden_states = (
  724. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  725. )
  726. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  727. input_shape = self._get_input_shape(input_ids, inputs_embeds)
  728. device = input_ids.device if input_ids is not None else inputs_embeds.device
  729. visual_shape = list(input_shape)
  730. visual_shape[1] = self.config.image_feature_pool_shape[0] * self.config.image_feature_pool_shape[1]
  731. visual_shape = torch.Size(visual_shape)
  732. # needs a new copy of input_shape for tracing. Otherwise wrong dimensions will occur
  733. final_shape = list(self._get_input_shape(input_ids, inputs_embeds))
  734. final_shape[1] += visual_shape[1]
  735. final_shape = torch.Size(final_shape)
  736. visual_bbox = self._calc_visual_bbox(self.config.image_feature_pool_shape, bbox, device, final_shape)
  737. final_bbox = torch.cat([bbox, visual_bbox], dim=1)
  738. if attention_mask is None:
  739. attention_mask = torch.ones(input_shape, device=device)
  740. visual_attention_mask = torch.ones(visual_shape, device=device)
  741. final_attention_mask = torch.cat([attention_mask, visual_attention_mask], dim=1)
  742. if token_type_ids is None:
  743. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  744. if position_ids is None:
  745. seq_length = input_shape[1]
  746. position_ids = self.embeddings.position_ids[:, :seq_length]
  747. position_ids = position_ids.expand(input_shape)
  748. visual_position_ids = torch.arange(0, visual_shape[1], dtype=torch.long, device=device).repeat(
  749. input_shape[0], 1
  750. )
  751. final_position_ids = torch.cat([position_ids, visual_position_ids], dim=1)
  752. if bbox is None:
  753. bbox = torch.zeros(tuple(list(input_shape) + [4]), dtype=torch.long, device=device)
  754. text_layout_emb = self._calc_text_embeddings(
  755. input_ids=input_ids,
  756. bbox=bbox,
  757. token_type_ids=token_type_ids,
  758. position_ids=position_ids,
  759. inputs_embeds=inputs_embeds,
  760. )
  761. visual_emb = self._calc_img_embeddings(
  762. image=image,
  763. bbox=visual_bbox,
  764. position_ids=visual_position_ids,
  765. )
  766. final_emb = torch.cat([text_layout_emb, visual_emb], dim=1)
  767. extended_attention_mask = final_attention_mask.unsqueeze(1).unsqueeze(2)
  768. extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
  769. extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
  770. if head_mask is not None:
  771. if head_mask.dim() == 1:
  772. head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
  773. head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
  774. elif head_mask.dim() == 2:
  775. head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
  776. head_mask = head_mask.to(dtype=next(self.parameters()).dtype)
  777. else:
  778. head_mask = [None] * self.config.num_hidden_layers
  779. encoder_outputs = self.encoder(
  780. final_emb,
  781. extended_attention_mask,
  782. bbox=final_bbox,
  783. position_ids=final_position_ids,
  784. head_mask=head_mask,
  785. output_attentions=output_attentions,
  786. output_hidden_states=output_hidden_states,
  787. return_dict=return_dict,
  788. )
  789. sequence_output = encoder_outputs[0]
  790. pooled_output = self.pooler(sequence_output)
  791. if not return_dict:
  792. return (sequence_output, pooled_output) + encoder_outputs[1:]
  793. return BaseModelOutputWithPooling(
  794. last_hidden_state=sequence_output,
  795. pooler_output=pooled_output,
  796. hidden_states=encoder_outputs.hidden_states,
  797. attentions=encoder_outputs.attentions,
  798. )
@add_start_docstrings(
    """
    LayoutLMv2 Model with a sequence classification head on top (a linear layer on top of the concatenation of the
    final hidden state of the [CLS] token, average-pooled initial visual embeddings and average-pooled final visual
    embeddings, e.g. for document image classification tasks such as the
    [RVL-CDIP](https://www.cs.cmu.edu/~aharley/rvl-cdip/) dataset.
    """,
    LAYOUTLMV2_START_DOCSTRING,
)
class LayoutLMv2ForSequenceClassification(LayoutLMv2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.layoutlmv2 = LayoutLMv2Model(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # * 3 because the head sees [CLS] final state + mean initial visual
        # embeddings + mean final visual embeddings, concatenated (see forward).
        self.classifier = nn.Linear(config.hidden_size * 3, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.layoutlmv2.embeddings.word_embeddings

    @add_start_docstrings_to_model_forward(LAYOUTLMV2_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        bbox: Optional[torch.LongTensor] = None,
        image: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Returns:

        Example:

        ```python
        >>> from transformers import AutoProcessor, LayoutLMv2ForSequenceClassification, set_seed
        >>> from PIL import Image
        >>> import torch
        >>> from datasets import load_dataset

        >>> set_seed(0)

        >>> dataset = load_dataset("aharley/rvl_cdip", split="train", streaming=True, trust_remote_code=True)
        >>> data = next(iter(dataset))
        >>> image = data["image"].convert("RGB")

        >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv2-base-uncased")
        >>> model = LayoutLMv2ForSequenceClassification.from_pretrained(
        ...     "microsoft/layoutlmv2-base-uncased", num_labels=dataset.info.features["label"].num_classes
        ... )

        >>> encoding = processor(image, return_tensors="pt")
        >>> sequence_label = torch.tensor([data["label"]])

        >>> outputs = model(**encoding, labels=sequence_label)
        >>> loss, logits = outputs.loss, outputs.logits
        >>> predicted_idx = logits.argmax(dim=-1).item()
        >>> predicted_answer = dataset.info.features["label"].names[4]
        >>> predicted_idx, predicted_answer  # results are not good without further fine-tuning
        (7, 'advertisement')
        ```
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Validate that exactly one of input_ids / inputs_embeds was given.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # Re-derive the visual token geometry exactly as LayoutLMv2Model.forward does,
        # so the *initial* visual embeddings (before the joint encoder) can be computed
        # here for the classification head.
        visual_shape = list(input_shape)
        visual_shape[1] = self.config.image_feature_pool_shape[0] * self.config.image_feature_pool_shape[1]
        visual_shape = torch.Size(visual_shape)
        final_shape = list(input_shape)
        final_shape[1] += visual_shape[1]
        final_shape = torch.Size(final_shape)

        # NOTE(review): `bbox` is consumed here without a None-default (its dtype is
        # read inside _calc_visual_bbox) — callers appear expected to always pass it;
        # confirm bbox=None is unsupported for this head.
        visual_bbox = self.layoutlmv2._calc_visual_bbox(
            self.config.image_feature_pool_shape, bbox, device, final_shape
        )

        visual_position_ids = torch.arange(0, visual_shape[1], dtype=torch.long, device=device).repeat(
            input_shape[0], 1
        )

        # Visual embeddings straight out of the backbone (not yet contextualized by
        # the joint text+image encoder).
        initial_image_embeddings = self.layoutlmv2._calc_img_embeddings(
            image=image,
            bbox=visual_bbox,
            position_ids=visual_position_ids,
        )

        outputs = self.layoutlmv2(
            input_ids=input_ids,
            bbox=bbox,
            image=image,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]
        # The encoder output is [text tokens | visual tokens]; split it back apart.
        sequence_output, final_image_embeddings = outputs[0][:, :seq_length], outputs[0][:, seq_length:]

        cls_final_output = sequence_output[:, 0, :]

        # average-pool the visual embeddings
        pooled_initial_image_embeddings = initial_image_embeddings.mean(dim=1)
        pooled_final_image_embeddings = final_image_embeddings.mean(dim=1)
        # concatenate with cls_final_output
        sequence_output = torch.cat(
            [cls_final_output, pooled_initial_image_embeddings, pooled_final_image_embeddings], dim=1
        )
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # Infer the problem type once (standard HF convention) and cache it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
  953. @add_start_docstrings(
  954. """
  955. LayoutLMv2 Model with a token classification head on top (a linear layer on top of the text part of the hidden
  956. states) e.g. for sequence labeling (information extraction) tasks such as
  957. [FUNSD](https://guillaumejaume.github.io/FUNSD/), [SROIE](https://rrc.cvc.uab.es/?ch=13),
  958. [CORD](https://github.com/clovaai/cord) and [Kleister-NDA](https://github.com/applicaai/kleister-nda).
  959. """,
  960. LAYOUTLMV2_START_DOCSTRING,
  961. )
  962. class LayoutLMv2ForTokenClassification(LayoutLMv2PreTrainedModel):
  963. def __init__(self, config):
  964. super().__init__(config)
  965. self.num_labels = config.num_labels
  966. self.layoutlmv2 = LayoutLMv2Model(config)
  967. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  968. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  969. # Initialize weights and apply final processing
  970. self.post_init()
  971. def get_input_embeddings(self):
  972. return self.layoutlmv2.embeddings.word_embeddings
  973. @add_start_docstrings_to_model_forward(LAYOUTLMV2_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  974. @replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC)
  975. def forward(
  976. self,
  977. input_ids: Optional[torch.LongTensor] = None,
  978. bbox: Optional[torch.LongTensor] = None,
  979. image: Optional[torch.FloatTensor] = None,
  980. attention_mask: Optional[torch.FloatTensor] = None,
  981. token_type_ids: Optional[torch.LongTensor] = None,
  982. position_ids: Optional[torch.LongTensor] = None,
  983. head_mask: Optional[torch.FloatTensor] = None,
  984. inputs_embeds: Optional[torch.FloatTensor] = None,
  985. labels: Optional[torch.LongTensor] = None,
  986. output_attentions: Optional[bool] = None,
  987. output_hidden_states: Optional[bool] = None,
  988. return_dict: Optional[bool] = None,
  989. ) -> Union[Tuple, TokenClassifierOutput]:
  990. r"""
  991. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  992. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  993. Returns:
  994. Example:
  995. ```python
  996. >>> from transformers import AutoProcessor, LayoutLMv2ForTokenClassification, set_seed
  997. >>> from PIL import Image
  998. >>> from datasets import load_dataset
  999. >>> set_seed(0)
  1000. >>> datasets = load_dataset("nielsr/funsd", split="test", trust_remote_code=True)
  1001. >>> labels = datasets.features["ner_tags"].feature.names
  1002. >>> id2label = {v: k for v, k in enumerate(labels)}
  1003. >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv2-base-uncased", revision="no_ocr")
  1004. >>> model = LayoutLMv2ForTokenClassification.from_pretrained(
  1005. ... "microsoft/layoutlmv2-base-uncased", num_labels=len(labels)
  1006. ... )
  1007. >>> data = datasets[0]
  1008. >>> image = Image.open(data["image_path"]).convert("RGB")
  1009. >>> words = data["words"]
  1010. >>> boxes = data["bboxes"] # make sure to normalize your bounding boxes
  1011. >>> word_labels = data["ner_tags"]
  1012. >>> encoding = processor(
  1013. ... image,
  1014. ... words,
  1015. ... boxes=boxes,
  1016. ... word_labels=word_labels,
  1017. ... padding="max_length",
  1018. ... truncation=True,
  1019. ... return_tensors="pt",
  1020. ... )
  1021. >>> outputs = model(**encoding)
  1022. >>> logits, loss = outputs.logits, outputs.loss
  1023. >>> predicted_token_class_ids = logits.argmax(-1)
  1024. >>> predicted_tokens_classes = [id2label[t.item()] for t in predicted_token_class_ids[0]]
  1025. >>> predicted_tokens_classes[:5] # results are not good without further fine-tuning
  1026. ['I-HEADER', 'I-HEADER', 'I-QUESTION', 'I-HEADER', 'I-QUESTION']
  1027. ```
  1028. """
  1029. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1030. outputs = self.layoutlmv2(
  1031. input_ids=input_ids,
  1032. bbox=bbox,
  1033. image=image,
  1034. attention_mask=attention_mask,
  1035. token_type_ids=token_type_ids,
  1036. position_ids=position_ids,
  1037. head_mask=head_mask,
  1038. inputs_embeds=inputs_embeds,
  1039. output_attentions=output_attentions,
  1040. output_hidden_states=output_hidden_states,
  1041. return_dict=return_dict,
  1042. )
  1043. if input_ids is not None:
  1044. input_shape = input_ids.size()
  1045. else:
  1046. input_shape = inputs_embeds.size()[:-1]
  1047. seq_length = input_shape[1]
  1048. # only take the text part of the output representations
  1049. sequence_output = outputs[0][:, :seq_length]
  1050. sequence_output = self.dropout(sequence_output)
  1051. logits = self.classifier(sequence_output)
  1052. loss = None
  1053. if labels is not None:
  1054. loss_fct = CrossEntropyLoss()
  1055. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1056. if not return_dict:
  1057. output = (logits,) + outputs[2:]
  1058. return ((loss,) + output) if loss is not None else output
  1059. return TokenClassifierOutput(
  1060. loss=loss,
  1061. logits=logits,
  1062. hidden_states=outputs.hidden_states,
  1063. attentions=outputs.attentions,
  1064. )
  1065. @add_start_docstrings(
  1066. """
  1067. LayoutLMv2 Model with a span classification head on top for extractive question-answering tasks such as
  1068. [DocVQA](https://rrc.cvc.uab.es/?ch=17) (a linear layer on top of the text part of the hidden-states output to
  1069. compute `span start logits` and `span end logits`).
  1070. """,
  1071. LAYOUTLMV2_START_DOCSTRING,
  1072. )
  1073. class LayoutLMv2ForQuestionAnswering(LayoutLMv2PreTrainedModel):
  1074. def __init__(self, config, has_visual_segment_embedding=True):
  1075. super().__init__(config)
  1076. self.num_labels = config.num_labels
  1077. config.has_visual_segment_embedding = has_visual_segment_embedding
  1078. self.layoutlmv2 = LayoutLMv2Model(config)
  1079. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  1080. # Initialize weights and apply final processing
  1081. self.post_init()
  1082. def get_input_embeddings(self):
  1083. return self.layoutlmv2.embeddings.word_embeddings
  1084. @add_start_docstrings_to_model_forward(LAYOUTLMV2_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  1085. @replace_return_docstrings(output_type=QuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC)
  1086. def forward(
  1087. self,
  1088. input_ids: Optional[torch.LongTensor] = None,
  1089. bbox: Optional[torch.LongTensor] = None,
  1090. image: Optional[torch.FloatTensor] = None,
  1091. attention_mask: Optional[torch.FloatTensor] = None,
  1092. token_type_ids: Optional[torch.LongTensor] = None,
  1093. position_ids: Optional[torch.LongTensor] = None,
  1094. head_mask: Optional[torch.FloatTensor] = None,
  1095. inputs_embeds: Optional[torch.FloatTensor] = None,
  1096. start_positions: Optional[torch.LongTensor] = None,
  1097. end_positions: Optional[torch.LongTensor] = None,
  1098. output_attentions: Optional[bool] = None,
  1099. output_hidden_states: Optional[bool] = None,
  1100. return_dict: Optional[bool] = None,
  1101. ) -> Union[Tuple, QuestionAnsweringModelOutput]:
  1102. r"""
  1103. start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1104. Labels for position (index) of the start of the labelled span for computing the token classification loss.
  1105. Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
  1106. are not taken into account for computing the loss.
  1107. end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1108. Labels for position (index) of the end of the labelled span for computing the token classification loss.
  1109. Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
  1110. are not taken into account for computing the loss.
  1111. Returns:
  1112. Example:
  1113. In this example below, we give the LayoutLMv2 model an image (of texts) and ask it a question. It will give us
  1114. a prediction of what it thinks the answer is (the span of the answer within the texts parsed from the image).
  1115. ```python
  1116. >>> from transformers import AutoProcessor, LayoutLMv2ForQuestionAnswering, set_seed
  1117. >>> import torch
  1118. >>> from PIL import Image
  1119. >>> from datasets import load_dataset
  1120. >>> set_seed(0)
  1121. >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv2-base-uncased")
  1122. >>> model = LayoutLMv2ForQuestionAnswering.from_pretrained("microsoft/layoutlmv2-base-uncased")
  1123. >>> dataset = load_dataset("hf-internal-testing/fixtures_docvqa", trust_remote_code=True)
  1124. >>> image_path = dataset["test"][0]["file"]
  1125. >>> image = Image.open(image_path).convert("RGB")
  1126. >>> question = "When is coffee break?"
  1127. >>> encoding = processor(image, question, return_tensors="pt")
  1128. >>> outputs = model(**encoding)
  1129. >>> predicted_start_idx = outputs.start_logits.argmax(-1).item()
  1130. >>> predicted_end_idx = outputs.end_logits.argmax(-1).item()
  1131. >>> predicted_start_idx, predicted_end_idx
  1132. (30, 191)
  1133. >>> predicted_answer_tokens = encoding.input_ids.squeeze()[predicted_start_idx : predicted_end_idx + 1]
  1134. >>> predicted_answer = processor.tokenizer.decode(predicted_answer_tokens)
  1135. >>> predicted_answer # results are not good without further fine-tuning
  1136. '44 a. m. to 12 : 25 p. m. 12 : 25 to 12 : 58 p. m. 12 : 58 to 4 : 00 p. m. 2 : 00 to 5 : 00 p. m. coffee break coffee will be served for men and women in the lobby adjacent to exhibit area. please move into exhibit area. ( exhibits open ) trrf general session ( part | ) presiding : lee a. waller trrf vice president “ introductory remarks ” lee a. waller, trrf vice presi - dent individual interviews with trrf public board members and sci - entific advisory council mem - bers conducted by trrf treasurer philip g. kuehn to get answers which the public refrigerated warehousing industry is looking for. plus questions from'
  1137. ```
  1138. ```python
  1139. >>> target_start_index = torch.tensor([7])
  1140. >>> target_end_index = torch.tensor([14])
  1141. >>> outputs = model(**encoding, start_positions=target_start_index, end_positions=target_end_index)
  1142. >>> predicted_answer_span_start = outputs.start_logits.argmax(-1).item()
  1143. >>> predicted_answer_span_end = outputs.end_logits.argmax(-1).item()
  1144. >>> predicted_answer_span_start, predicted_answer_span_end
  1145. (30, 191)
  1146. ```
  1147. """
  1148. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1149. outputs = self.layoutlmv2(
  1150. input_ids=input_ids,
  1151. bbox=bbox,
  1152. image=image,
  1153. attention_mask=attention_mask,
  1154. token_type_ids=token_type_ids,
  1155. position_ids=position_ids,
  1156. head_mask=head_mask,
  1157. inputs_embeds=inputs_embeds,
  1158. output_attentions=output_attentions,
  1159. output_hidden_states=output_hidden_states,
  1160. return_dict=return_dict,
  1161. )
  1162. if input_ids is not None:
  1163. input_shape = input_ids.size()
  1164. else:
  1165. input_shape = inputs_embeds.size()[:-1]
  1166. seq_length = input_shape[1]
  1167. # only take the text part of the output representations
  1168. sequence_output = outputs[0][:, :seq_length]
  1169. logits = self.qa_outputs(sequence_output)
  1170. start_logits, end_logits = logits.split(1, dim=-1)
  1171. start_logits = start_logits.squeeze(-1).contiguous()
  1172. end_logits = end_logits.squeeze(-1).contiguous()
  1173. total_loss = None
  1174. if start_positions is not None and end_positions is not None:
  1175. # If we are on multi-GPU, split add a dimension
  1176. if len(start_positions.size()) > 1:
  1177. start_positions = start_positions.squeeze(-1)
  1178. if len(end_positions.size()) > 1:
  1179. end_positions = end_positions.squeeze(-1)
  1180. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  1181. ignored_index = start_logits.size(1)
  1182. start_positions = start_positions.clamp(0, ignored_index)
  1183. end_positions = end_positions.clamp(0, ignored_index)
  1184. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  1185. start_loss = loss_fct(start_logits, start_positions)
  1186. end_loss = loss_fct(end_logits, end_positions)
  1187. total_loss = (start_loss + end_loss) / 2
  1188. if not return_dict:
  1189. output = (start_logits, end_logits) + outputs[2:]
  1190. return ((total_loss,) + output) if total_loss is not None else output
  1191. return QuestionAnsweringModelOutput(
  1192. loss=total_loss,
  1193. start_logits=start_logits,
  1194. end_logits=end_logits,
  1195. hidden_states=outputs.hidden_states,
  1196. attentions=outputs.attentions,
  1197. )