modeling_dpt.py 56 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372
  1. # coding=utf-8
  2. # Copyright 2022 Intel Labs, OpenMMLab and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch DPT (Dense Prediction Transformers) model.
  16. This implementation is heavily inspired by OpenMMLab's implementation, found here:
  17. https://github.com/open-mmlab/mmsegmentation/blob/master/mmseg/models/decode_heads/dpt_head.py.
  18. """
  19. import collections.abc
  20. import math
  21. from dataclasses import dataclass
  22. from typing import List, Optional, Set, Tuple, Union
  23. import torch
  24. import torch.utils.checkpoint
  25. from torch import nn
  26. from torch.nn import CrossEntropyLoss
  27. from ...activations import ACT2FN
  28. from ...file_utils import (
  29. add_code_sample_docstrings,
  30. add_start_docstrings,
  31. add_start_docstrings_to_model_forward,
  32. replace_return_docstrings,
  33. )
  34. from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput
  35. from ...modeling_utils import PreTrainedModel
  36. from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
  37. from ...utils import ModelOutput, logging, torch_int
  38. from ...utils.backbone_utils import load_backbone
  39. from .configuration_dpt import DPTConfig
  40. logger = logging.get_logger(__name__)
  41. # General docstring
  42. _CONFIG_FOR_DOC = "DPTConfig"
  43. # Base docstring
  44. _CHECKPOINT_FOR_DOC = "Intel/dpt-large"
  45. _EXPECTED_OUTPUT_SHAPE = [1, 577, 1024]
  46. @dataclass
  47. class BaseModelOutputWithIntermediateActivations(ModelOutput):
  48. """
  49. Base class for model's outputs that also contains intermediate activations that can be used at later stages. Useful
  50. in the context of Vision models.:
  51. Args:
  52. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  53. Sequence of hidden-states at the output of the last layer of the model.
  54. intermediate_activations (`tuple(torch.FloatTensor)`, *optional*):
  55. Intermediate activations that can be used to compute hidden states of the model at various layers.
  56. """
  57. last_hidden_states: torch.FloatTensor = None
  58. intermediate_activations: Optional[Tuple[torch.FloatTensor, ...]] = None
  59. @dataclass
  60. class BaseModelOutputWithPoolingAndIntermediateActivations(ModelOutput):
  61. """
  62. Base class for model's outputs that also contains a pooling of the last hidden states as well as intermediate
  63. activations that can be used by the model at later stages.
  64. Args:
  65. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  66. Sequence of hidden-states at the output of the last layer of the model.
  67. pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
  68. Last layer hidden-state of the first token of the sequence (classification token) after further processing
  69. through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
  70. the classification token after processing through a linear layer and a tanh activation function. The linear
  71. layer weights are trained from the next sentence prediction (classification) objective during pretraining.
  72. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  73. Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
  74. one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
  75. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
  76. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  77. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  78. sequence_length)`.
  79. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
  80. heads.
  81. intermediate_activations (`tuple(torch.FloatTensor)`, *optional*):
  82. Intermediate activations that can be used to compute hidden states of the model at various layers.
  83. """
  84. last_hidden_state: torch.FloatTensor = None
  85. pooler_output: torch.FloatTensor = None
  86. hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
  87. attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
  88. intermediate_activations: Optional[Tuple[torch.FloatTensor, ...]] = None
  89. class DPTViTHybridEmbeddings(nn.Module):
  90. """
  91. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  92. `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
  93. Transformer.
  94. """
  95. def __init__(self, config, feature_size=None):
  96. super().__init__()
  97. image_size, patch_size = config.image_size, config.patch_size
  98. num_channels, hidden_size = config.num_channels, config.hidden_size
  99. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  100. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  101. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  102. self.backbone = load_backbone(config)
  103. feature_dim = self.backbone.channels[-1]
  104. if len(self.backbone.channels) != 3:
  105. raise ValueError(f"Expected backbone to have 3 output features, got {len(self.backbone.channels)}")
  106. self.residual_feature_map_index = [0, 1] # Always take the output of the first and second backbone stage
  107. if feature_size is None:
  108. feat_map_shape = config.backbone_featmap_shape
  109. feature_size = feat_map_shape[-2:]
  110. feature_dim = feat_map_shape[1]
  111. else:
  112. feature_size = (
  113. feature_size if isinstance(feature_size, collections.abc.Iterable) else (feature_size, feature_size)
  114. )
  115. feature_dim = self.backbone.channels[-1]
  116. self.image_size = image_size
  117. self.patch_size = patch_size[0]
  118. self.num_channels = num_channels
  119. self.projection = nn.Conv2d(feature_dim, hidden_size, kernel_size=1)
  120. self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  121. self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
  122. def _resize_pos_embed(self, posemb, grid_size_height, grid_size_width, start_index=1):
  123. posemb_tok = posemb[:, :start_index]
  124. posemb_grid = posemb[0, start_index:]
  125. old_grid_size = torch_int(len(posemb_grid) ** 0.5)
  126. posemb_grid = posemb_grid.reshape(1, old_grid_size, old_grid_size, -1).permute(0, 3, 1, 2)
  127. posemb_grid = nn.functional.interpolate(posemb_grid, size=(grid_size_height, grid_size_width), mode="bilinear")
  128. posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, grid_size_height * grid_size_width, -1)
  129. posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
  130. return posemb
  131. def forward(
  132. self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False, return_dict: bool = False
  133. ) -> torch.Tensor:
  134. batch_size, num_channels, height, width = pixel_values.shape
  135. if num_channels != self.num_channels:
  136. raise ValueError(
  137. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  138. )
  139. if not interpolate_pos_encoding:
  140. if height != self.image_size[0] or width != self.image_size[1]:
  141. raise ValueError(
  142. f"Input image size ({height}*{width}) doesn't match model"
  143. f" ({self.image_size[0]}*{self.image_size[1]})."
  144. )
  145. position_embeddings = self._resize_pos_embed(
  146. self.position_embeddings, height // self.patch_size, width // self.patch_size
  147. )
  148. backbone_output = self.backbone(pixel_values)
  149. features = backbone_output.feature_maps[-1]
  150. # Retrieve also the intermediate activations to use them at later stages
  151. output_hidden_states = [backbone_output.feature_maps[index] for index in self.residual_feature_map_index]
  152. embeddings = self.projection(features).flatten(2).transpose(1, 2)
  153. cls_tokens = self.cls_token.expand(batch_size, -1, -1)
  154. embeddings = torch.cat((cls_tokens, embeddings), dim=1)
  155. # add positional encoding to each token
  156. embeddings = embeddings + position_embeddings
  157. if not return_dict:
  158. return (embeddings, output_hidden_states)
  159. # Return hidden states and intermediate activations
  160. return BaseModelOutputWithIntermediateActivations(
  161. last_hidden_states=embeddings,
  162. intermediate_activations=output_hidden_states,
  163. )
  164. class DPTViTEmbeddings(nn.Module):
  165. """
  166. Construct the CLS token, position and patch embeddings.
  167. """
  168. def __init__(self, config):
  169. super().__init__()
  170. self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  171. self.patch_embeddings = DPTViTPatchEmbeddings(config)
  172. num_patches = self.patch_embeddings.num_patches
  173. self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
  174. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  175. self.config = config
  176. def _resize_pos_embed(self, posemb, grid_size_height, grid_size_width, start_index=1):
  177. posemb_tok = posemb[:, :start_index]
  178. posemb_grid = posemb[0, start_index:]
  179. old_grid_size = torch_int(posemb_grid.size(0) ** 0.5)
  180. posemb_grid = posemb_grid.reshape(1, old_grid_size, old_grid_size, -1).permute(0, 3, 1, 2)
  181. posemb_grid = nn.functional.interpolate(posemb_grid, size=(grid_size_height, grid_size_width), mode="bilinear")
  182. posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, grid_size_height * grid_size_width, -1)
  183. posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
  184. return posemb
  185. def forward(self, pixel_values, return_dict=False):
  186. batch_size, num_channels, height, width = pixel_values.shape
  187. # possibly interpolate position encodings to handle varying image sizes
  188. patch_size = self.config.patch_size
  189. position_embeddings = self._resize_pos_embed(
  190. self.position_embeddings, height // patch_size, width // patch_size
  191. )
  192. embeddings = self.patch_embeddings(pixel_values)
  193. batch_size, seq_len, _ = embeddings.size()
  194. # add the [CLS] token to the embedded patch tokens
  195. cls_tokens = self.cls_token.expand(batch_size, -1, -1)
  196. embeddings = torch.cat((cls_tokens, embeddings), dim=1)
  197. # add positional encoding to each token
  198. embeddings = embeddings + position_embeddings
  199. embeddings = self.dropout(embeddings)
  200. if not return_dict:
  201. return (embeddings,)
  202. return BaseModelOutputWithIntermediateActivations(last_hidden_states=embeddings)
  203. class DPTViTPatchEmbeddings(nn.Module):
  204. """
  205. Image to Patch Embedding.
  206. """
  207. def __init__(self, config):
  208. super().__init__()
  209. image_size, patch_size = config.image_size, config.patch_size
  210. num_channels, hidden_size = config.num_channels, config.hidden_size
  211. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  212. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  213. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  214. self.image_size = image_size
  215. self.patch_size = patch_size
  216. self.num_channels = num_channels
  217. self.num_patches = num_patches
  218. self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
  219. def forward(self, pixel_values):
  220. batch_size, num_channels, height, width = pixel_values.shape
  221. if num_channels != self.num_channels:
  222. raise ValueError(
  223. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  224. )
  225. embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
  226. return embeddings
  227. # Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->DPT
  228. class DPTViTSelfAttention(nn.Module):
  229. def __init__(self, config: DPTConfig) -> None:
  230. super().__init__()
  231. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  232. raise ValueError(
  233. f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
  234. f"heads {config.num_attention_heads}."
  235. )
  236. self.num_attention_heads = config.num_attention_heads
  237. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  238. self.all_head_size = self.num_attention_heads * self.attention_head_size
  239. self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  240. self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  241. self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
  242. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  243. def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
  244. new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
  245. x = x.view(new_x_shape)
  246. return x.permute(0, 2, 1, 3)
  247. def forward(
  248. self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
  249. ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
  250. mixed_query_layer = self.query(hidden_states)
  251. key_layer = self.transpose_for_scores(self.key(hidden_states))
  252. value_layer = self.transpose_for_scores(self.value(hidden_states))
  253. query_layer = self.transpose_for_scores(mixed_query_layer)
  254. # Take the dot product between "query" and "key" to get the raw attention scores.
  255. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  256. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  257. # Normalize the attention scores to probabilities.
  258. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  259. # This is actually dropping out entire tokens to attend to, which might
  260. # seem a bit unusual, but is taken from the original Transformer paper.
  261. attention_probs = self.dropout(attention_probs)
  262. # Mask heads if we want to
  263. if head_mask is not None:
  264. attention_probs = attention_probs * head_mask
  265. context_layer = torch.matmul(attention_probs, value_layer)
  266. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  267. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  268. context_layer = context_layer.view(new_context_layer_shape)
  269. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  270. return outputs
  271. # Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->DPT
  272. class DPTViTSelfOutput(nn.Module):
  273. """
  274. The residual connection is defined in DPTLayer instead of here (as is the case with other models), due to the
  275. layernorm applied before each block.
  276. """
  277. def __init__(self, config: DPTConfig) -> None:
  278. super().__init__()
  279. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  280. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  281. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  282. hidden_states = self.dense(hidden_states)
  283. hidden_states = self.dropout(hidden_states)
  284. return hidden_states
  285. class DPTViTAttention(nn.Module):
  286. def __init__(self, config: DPTConfig) -> None:
  287. super().__init__()
  288. self.attention = DPTViTSelfAttention(config)
  289. self.output = DPTViTSelfOutput(config)
  290. self.pruned_heads = set()
  291. # Copied from transformers.models.vit.modeling_vit.ViTAttention.prune_heads
  292. def prune_heads(self, heads: Set[int]) -> None:
  293. if len(heads) == 0:
  294. return
  295. heads, index = find_pruneable_heads_and_indices(
  296. heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
  297. )
  298. # Prune linear layers
  299. self.attention.query = prune_linear_layer(self.attention.query, index)
  300. self.attention.key = prune_linear_layer(self.attention.key, index)
  301. self.attention.value = prune_linear_layer(self.attention.value, index)
  302. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  303. # Update hyper params and store pruned heads
  304. self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
  305. self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
  306. self.pruned_heads = self.pruned_heads.union(heads)
  307. # Copied from transformers.models.vit.modeling_vit.ViTAttention.forward
  308. def forward(
  309. self,
  310. hidden_states: torch.Tensor,
  311. head_mask: Optional[torch.Tensor] = None,
  312. output_attentions: bool = False,
  313. ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
  314. self_outputs = self.attention(hidden_states, head_mask, output_attentions)
  315. attention_output = self.output(self_outputs[0], hidden_states)
  316. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  317. return outputs
  318. # Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->DPT
  319. class DPTViTIntermediate(nn.Module):
  320. def __init__(self, config: DPTConfig) -> None:
  321. super().__init__()
  322. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  323. if isinstance(config.hidden_act, str):
  324. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  325. else:
  326. self.intermediate_act_fn = config.hidden_act
  327. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  328. hidden_states = self.dense(hidden_states)
  329. hidden_states = self.intermediate_act_fn(hidden_states)
  330. return hidden_states
  331. # Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->DPT
  332. class DPTViTOutput(nn.Module):
  333. def __init__(self, config: DPTConfig) -> None:
  334. super().__init__()
  335. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  336. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  337. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  338. hidden_states = self.dense(hidden_states)
  339. hidden_states = self.dropout(hidden_states)
  340. hidden_states = hidden_states + input_tensor
  341. return hidden_states
  342. # copied from transformers.models.vit.modeling_vit.ViTLayer with ViTConfig->DPTConfig, ViTAttention->DPTViTAttention, ViTIntermediate->DPTViTIntermediate, ViTOutput->DPTViTOutput
  343. class DPTViTLayer(nn.Module):
  344. """This corresponds to the Block class in the timm implementation."""
  345. def __init__(self, config: DPTConfig) -> None:
  346. super().__init__()
  347. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  348. self.seq_len_dim = 1
  349. self.attention = DPTViTAttention(config)
  350. self.intermediate = DPTViTIntermediate(config)
  351. self.output = DPTViTOutput(config)
  352. self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  353. self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  354. def forward(
  355. self,
  356. hidden_states: torch.Tensor,
  357. head_mask: Optional[torch.Tensor] = None,
  358. output_attentions: bool = False,
  359. ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
  360. self_attention_outputs = self.attention(
  361. self.layernorm_before(hidden_states), # in ViT, layernorm is applied before self-attention
  362. head_mask,
  363. output_attentions=output_attentions,
  364. )
  365. attention_output = self_attention_outputs[0]
  366. outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
  367. # first residual connection
  368. hidden_states = attention_output + hidden_states
  369. # in ViT, layernorm is also applied after self-attention
  370. layer_output = self.layernorm_after(hidden_states)
  371. layer_output = self.intermediate(layer_output)
  372. # second residual connection is done here
  373. layer_output = self.output(layer_output, hidden_states)
  374. outputs = (layer_output,) + outputs
  375. return outputs
  376. # copied from transformers.models.vit.modeling_vit.ViTEncoder with ViTConfig -> DPTConfig, ViTLayer->DPTViTLayer
  377. class DPTViTEncoder(nn.Module):
  378. def __init__(self, config: DPTConfig) -> None:
  379. super().__init__()
  380. self.config = config
  381. self.layer = nn.ModuleList([DPTViTLayer(config) for _ in range(config.num_hidden_layers)])
  382. self.gradient_checkpointing = False
  383. def forward(
  384. self,
  385. hidden_states: torch.Tensor,
  386. head_mask: Optional[torch.Tensor] = None,
  387. output_attentions: bool = False,
  388. output_hidden_states: bool = False,
  389. return_dict: bool = True,
  390. ) -> Union[tuple, BaseModelOutput]:
  391. all_hidden_states = () if output_hidden_states else None
  392. all_self_attentions = () if output_attentions else None
  393. for i, layer_module in enumerate(self.layer):
  394. if output_hidden_states:
  395. all_hidden_states = all_hidden_states + (hidden_states,)
  396. layer_head_mask = head_mask[i] if head_mask is not None else None
  397. if self.gradient_checkpointing and self.training:
  398. layer_outputs = self._gradient_checkpointing_func(
  399. layer_module.__call__,
  400. hidden_states,
  401. layer_head_mask,
  402. output_attentions,
  403. )
  404. else:
  405. layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
  406. hidden_states = layer_outputs[0]
  407. if output_attentions:
  408. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  409. if output_hidden_states:
  410. all_hidden_states = all_hidden_states + (hidden_states,)
  411. if not return_dict:
  412. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  413. return BaseModelOutput(
  414. last_hidden_state=hidden_states,
  415. hidden_states=all_hidden_states,
  416. attentions=all_self_attentions,
  417. )
  418. class DPTReassembleStage(nn.Module):
  419. """
  420. This class reassembles the hidden states of the backbone into image-like feature representations at various
  421. resolutions.
  422. This happens in 3 stages:
  423. 1. Map the N + 1 tokens to a set of N tokens, by taking into account the readout ([CLS]) token according to
  424. `config.readout_type`.
  425. 2. Project the channel dimension of the hidden states according to `config.neck_hidden_sizes`.
  426. 3. Resizing the spatial dimensions (height, width).
  427. Args:
  428. config (`[DPTConfig]`):
  429. Model configuration class defining the model architecture.
  430. """
  431. def __init__(self, config):
  432. super().__init__()
  433. self.config = config
  434. self.layers = nn.ModuleList()
  435. if config.is_hybrid:
  436. self._init_reassemble_dpt_hybrid(config)
  437. else:
  438. self._init_reassemble_dpt(config)
  439. self.neck_ignore_stages = config.neck_ignore_stages
  440. def _init_reassemble_dpt_hybrid(self, config):
  441. r""" "
  442. For DPT-Hybrid the first 2 reassemble layers are set to `nn.Identity()`, please check the official
  443. implementation: https://github.com/isl-org/DPT/blob/f43ef9e08d70a752195028a51be5e1aff227b913/dpt/vit.py#L438
  444. for more details.
  445. """
  446. for i, factor in zip(range(len(config.neck_hidden_sizes)), config.reassemble_factors):
  447. if i <= 1:
  448. self.layers.append(nn.Identity())
  449. elif i > 1:
  450. self.layers.append(DPTReassembleLayer(config, channels=config.neck_hidden_sizes[i], factor=factor))
  451. if config.readout_type != "project":
  452. raise ValueError(f"Readout type {config.readout_type} is not supported for DPT-Hybrid.")
  453. # When using DPT-Hybrid the readout type is set to "project". The sanity check is done on the config file
  454. self.readout_projects = nn.ModuleList()
  455. hidden_size = _get_backbone_hidden_size(config)
  456. for i in range(len(config.neck_hidden_sizes)):
  457. if i <= 1:
  458. self.readout_projects.append(nn.Sequential(nn.Identity()))
  459. elif i > 1:
  460. self.readout_projects.append(
  461. nn.Sequential(nn.Linear(2 * hidden_size, hidden_size), ACT2FN[config.hidden_act])
  462. )
  463. def _init_reassemble_dpt(self, config):
  464. for i, factor in zip(range(len(config.neck_hidden_sizes)), config.reassemble_factors):
  465. self.layers.append(DPTReassembleLayer(config, channels=config.neck_hidden_sizes[i], factor=factor))
  466. if config.readout_type == "project":
  467. self.readout_projects = nn.ModuleList()
  468. hidden_size = _get_backbone_hidden_size(config)
  469. for _ in range(len(config.neck_hidden_sizes)):
  470. self.readout_projects.append(
  471. nn.Sequential(nn.Linear(2 * hidden_size, hidden_size), ACT2FN[config.hidden_act])
  472. )
  473. def forward(self, hidden_states: List[torch.Tensor], patch_height=None, patch_width=None) -> List[torch.Tensor]:
  474. """
  475. Args:
  476. hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length + 1, hidden_size)`):
  477. List of hidden states from the backbone.
  478. """
  479. out = []
  480. for i, hidden_state in enumerate(hidden_states):
  481. if i not in self.neck_ignore_stages:
  482. # reshape to (batch_size, num_channels, height, width)
  483. cls_token, hidden_state = hidden_state[:, 0], hidden_state[:, 1:]
  484. batch_size, sequence_length, num_channels = hidden_state.shape
  485. if patch_height is not None and patch_width is not None:
  486. hidden_state = hidden_state.reshape(batch_size, patch_height, patch_width, num_channels)
  487. else:
  488. size = torch_int(sequence_length**0.5)
  489. hidden_state = hidden_state.reshape(batch_size, size, size, num_channels)
  490. hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
  491. feature_shape = hidden_state.shape
  492. if self.config.readout_type == "project":
  493. # reshape to (batch_size, height*width, num_channels)
  494. hidden_state = hidden_state.flatten(2).permute((0, 2, 1))
  495. readout = cls_token.unsqueeze(1).expand_as(hidden_state)
  496. # concatenate the readout token to the hidden states and project
  497. hidden_state = self.readout_projects[i](torch.cat((hidden_state, readout), -1))
  498. # reshape back to (batch_size, num_channels, height, width)
  499. hidden_state = hidden_state.permute(0, 2, 1).reshape(feature_shape)
  500. elif self.config.readout_type == "add":
  501. hidden_state = hidden_state.flatten(2) + cls_token.unsqueeze(-1)
  502. hidden_state = hidden_state.reshape(feature_shape)
  503. hidden_state = self.layers[i](hidden_state)
  504. out.append(hidden_state)
  505. return out
  506. def _get_backbone_hidden_size(config):
  507. if config.backbone_config is not None and config.is_hybrid is False:
  508. return config.backbone_config.hidden_size
  509. else:
  510. return config.hidden_size
  511. class DPTReassembleLayer(nn.Module):
  512. def __init__(self, config, channels, factor):
  513. super().__init__()
  514. # projection
  515. hidden_size = _get_backbone_hidden_size(config)
  516. self.projection = nn.Conv2d(in_channels=hidden_size, out_channels=channels, kernel_size=1)
  517. # up/down sampling depending on factor
  518. if factor > 1:
  519. self.resize = nn.ConvTranspose2d(channels, channels, kernel_size=factor, stride=factor, padding=0)
  520. elif factor == 1:
  521. self.resize = nn.Identity()
  522. elif factor < 1:
  523. # so should downsample
  524. self.resize = nn.Conv2d(channels, channels, kernel_size=3, stride=int(1 / factor), padding=1)
  525. def forward(self, hidden_state):
  526. hidden_state = self.projection(hidden_state)
  527. hidden_state = self.resize(hidden_state)
  528. return hidden_state
  529. class DPTFeatureFusionStage(nn.Module):
  530. def __init__(self, config):
  531. super().__init__()
  532. self.layers = nn.ModuleList()
  533. for _ in range(len(config.neck_hidden_sizes)):
  534. self.layers.append(DPTFeatureFusionLayer(config))
  535. def forward(self, hidden_states):
  536. # reversing the hidden_states, we start from the last
  537. hidden_states = hidden_states[::-1]
  538. fused_hidden_states = []
  539. # first layer only uses the last hidden_state
  540. fused_hidden_state = self.layers[0](hidden_states[0])
  541. fused_hidden_states.append(fused_hidden_state)
  542. # looping from the last layer to the second
  543. for hidden_state, layer in zip(hidden_states[1:], self.layers[1:]):
  544. fused_hidden_state = layer(fused_hidden_state, hidden_state)
  545. fused_hidden_states.append(fused_hidden_state)
  546. return fused_hidden_states
  547. class DPTPreActResidualLayer(nn.Module):
  548. """
  549. ResidualConvUnit, pre-activate residual unit.
  550. Args:
  551. config (`[DPTConfig]`):
  552. Model configuration class defining the model architecture.
  553. """
  554. def __init__(self, config):
  555. super().__init__()
  556. self.use_batch_norm = config.use_batch_norm_in_fusion_residual
  557. use_bias_in_fusion_residual = (
  558. config.use_bias_in_fusion_residual
  559. if config.use_bias_in_fusion_residual is not None
  560. else not self.use_batch_norm
  561. )
  562. self.activation1 = nn.ReLU()
  563. self.convolution1 = nn.Conv2d(
  564. config.fusion_hidden_size,
  565. config.fusion_hidden_size,
  566. kernel_size=3,
  567. stride=1,
  568. padding=1,
  569. bias=use_bias_in_fusion_residual,
  570. )
  571. self.activation2 = nn.ReLU()
  572. self.convolution2 = nn.Conv2d(
  573. config.fusion_hidden_size,
  574. config.fusion_hidden_size,
  575. kernel_size=3,
  576. stride=1,
  577. padding=1,
  578. bias=use_bias_in_fusion_residual,
  579. )
  580. if self.use_batch_norm:
  581. self.batch_norm1 = nn.BatchNorm2d(config.fusion_hidden_size)
  582. self.batch_norm2 = nn.BatchNorm2d(config.fusion_hidden_size)
  583. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  584. residual = hidden_state
  585. hidden_state = self.activation1(hidden_state)
  586. hidden_state = self.convolution1(hidden_state)
  587. if self.use_batch_norm:
  588. hidden_state = self.batch_norm1(hidden_state)
  589. hidden_state = self.activation2(hidden_state)
  590. hidden_state = self.convolution2(hidden_state)
  591. if self.use_batch_norm:
  592. hidden_state = self.batch_norm2(hidden_state)
  593. return hidden_state + residual
  594. class DPTFeatureFusionLayer(nn.Module):
  595. """Feature fusion layer, merges feature maps from different stages.
  596. Args:
  597. config (`[DPTConfig]`):
  598. Model configuration class defining the model architecture.
  599. align_corners (`bool`, *optional*, defaults to `True`):
  600. The align_corner setting for bilinear upsample.
  601. """
  602. def __init__(self, config, align_corners=True):
  603. super().__init__()
  604. self.align_corners = align_corners
  605. self.projection = nn.Conv2d(config.fusion_hidden_size, config.fusion_hidden_size, kernel_size=1, bias=True)
  606. self.residual_layer1 = DPTPreActResidualLayer(config)
  607. self.residual_layer2 = DPTPreActResidualLayer(config)
  608. def forward(self, hidden_state, residual=None):
  609. if residual is not None:
  610. if hidden_state.shape != residual.shape:
  611. residual = nn.functional.interpolate(
  612. residual, size=(hidden_state.shape[2], hidden_state.shape[3]), mode="bilinear", align_corners=False
  613. )
  614. hidden_state = hidden_state + self.residual_layer1(residual)
  615. hidden_state = self.residual_layer2(hidden_state)
  616. hidden_state = nn.functional.interpolate(
  617. hidden_state, scale_factor=2, mode="bilinear", align_corners=self.align_corners
  618. )
  619. hidden_state = self.projection(hidden_state)
  620. return hidden_state
  621. class DPTPreTrainedModel(PreTrainedModel):
  622. """
  623. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  624. models.
  625. """
  626. config_class = DPTConfig
  627. base_model_prefix = "dpt"
  628. main_input_name = "pixel_values"
  629. supports_gradient_checkpointing = True
  630. def _init_weights(self, module):
  631. """Initialize the weights"""
  632. if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
  633. # Slightly different from the TF version which uses truncated_normal for initialization
  634. # cf https://github.com/pytorch/pytorch/pull/5617
  635. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  636. if module.bias is not None:
  637. module.bias.data.zero_()
  638. elif isinstance(module, nn.LayerNorm):
  639. module.bias.data.zero_()
  640. module.weight.data.fill_(1.0)
  641. DPT_START_DOCSTRING = r"""
  642. This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
  643. as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
  644. behavior.
  645. Parameters:
  646. config ([`ViTConfig`]): Model configuration class with all the parameters of the model.
  647. Initializing with a config file does not load the weights associated with the model, only the
  648. configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
  649. """
  650. DPT_INPUTS_DOCSTRING = r"""
  651. Args:
  652. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
  653. Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`DPTImageProcessor.__call__`]
  654. for details.
  655. head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  656. Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
  657. - 1 indicates the head is **not masked**,
  658. - 0 indicates the head is **masked**.
  659. output_attentions (`bool`, *optional*):
  660. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  661. tensors for more detail.
  662. output_hidden_states (`bool`, *optional*):
  663. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  664. more detail.
  665. return_dict (`bool`, *optional*):
  666. Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
  667. """
  668. @add_start_docstrings(
  669. "The bare DPT Model transformer outputting raw hidden-states without any specific head on top.",
  670. DPT_START_DOCSTRING,
  671. )
  672. class DPTModel(DPTPreTrainedModel):
  673. def __init__(self, config, add_pooling_layer=True):
  674. super().__init__(config)
  675. self.config = config
  676. # vit encoder
  677. if config.is_hybrid:
  678. self.embeddings = DPTViTHybridEmbeddings(config)
  679. else:
  680. self.embeddings = DPTViTEmbeddings(config)
  681. self.encoder = DPTViTEncoder(config)
  682. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  683. self.pooler = DPTViTPooler(config) if add_pooling_layer else None
  684. # Initialize weights and apply final processing
  685. self.post_init()
  686. def get_input_embeddings(self):
  687. if self.config.is_hybrid:
  688. return self.embeddings
  689. else:
  690. return self.embeddings.patch_embeddings
  691. def _prune_heads(self, heads_to_prune):
  692. """
  693. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  694. class PreTrainedModel
  695. """
  696. for layer, heads in heads_to_prune.items():
  697. self.encoder.layer[layer].attention.prune_heads(heads)
  698. @add_start_docstrings_to_model_forward(DPT_INPUTS_DOCSTRING)
  699. @add_code_sample_docstrings(
  700. checkpoint=_CHECKPOINT_FOR_DOC,
  701. output_type=BaseModelOutputWithPoolingAndIntermediateActivations,
  702. config_class=_CONFIG_FOR_DOC,
  703. modality="vision",
  704. expected_output=_EXPECTED_OUTPUT_SHAPE,
  705. )
  706. def forward(
  707. self,
  708. pixel_values: torch.FloatTensor,
  709. head_mask: Optional[torch.FloatTensor] = None,
  710. output_attentions: Optional[bool] = None,
  711. output_hidden_states: Optional[bool] = None,
  712. return_dict: Optional[bool] = None,
  713. ) -> Union[Tuple, BaseModelOutputWithPoolingAndIntermediateActivations]:
  714. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  715. output_hidden_states = (
  716. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  717. )
  718. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  719. # Prepare head mask if needed
  720. # 1.0 in head_mask indicate we keep the head
  721. # attention_probs has shape bsz x n_heads x N x N
  722. # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  723. # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  724. head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
  725. embedding_output = self.embeddings(pixel_values, return_dict=return_dict)
  726. embedding_last_hidden_states = embedding_output[0] if not return_dict else embedding_output.last_hidden_states
  727. encoder_outputs = self.encoder(
  728. embedding_last_hidden_states,
  729. head_mask=head_mask,
  730. output_attentions=output_attentions,
  731. output_hidden_states=output_hidden_states,
  732. return_dict=return_dict,
  733. )
  734. sequence_output = encoder_outputs[0]
  735. sequence_output = self.layernorm(sequence_output)
  736. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  737. if not return_dict:
  738. head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
  739. return head_outputs + encoder_outputs[1:] + embedding_output[1:]
  740. return BaseModelOutputWithPoolingAndIntermediateActivations(
  741. last_hidden_state=sequence_output,
  742. pooler_output=pooled_output,
  743. hidden_states=encoder_outputs.hidden_states,
  744. attentions=encoder_outputs.attentions,
  745. intermediate_activations=embedding_output.intermediate_activations,
  746. )
# Pooler used by `DPTModel` when `add_pooling_layer=True`: a dense layer followed by tanh, applied to the first token.
# NOTE(review): the "Copied from" marker below suggests this class is kept identical to ViTPooler by the repository's
# copy-consistency tooling, so changes probably belong in the ViT original. Confirm before editing the body here.
# Copied from transformers.models.vit.modeling_vit.ViTPooler with ViT->DPT
class DPTViTPooler(nn.Module):
    def __init__(self, config: DPTConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output
  760. class DPTNeck(nn.Module):
  761. """
  762. DPTNeck. A neck is a module that is normally used between the backbone and the head. It takes a list of tensors as
  763. input and produces another list of tensors as output. For DPT, it includes 2 stages:
  764. * DPTReassembleStage
  765. * DPTFeatureFusionStage.
  766. Args:
  767. config (dict): config dict.
  768. """
  769. def __init__(self, config):
  770. super().__init__()
  771. self.config = config
  772. # postprocessing: only required in case of a non-hierarchical backbone (e.g. ViT, BEiT)
  773. if config.backbone_config is not None and config.backbone_config.model_type in ["swinv2"]:
  774. self.reassemble_stage = None
  775. else:
  776. self.reassemble_stage = DPTReassembleStage(config)
  777. self.convs = nn.ModuleList()
  778. for channel in config.neck_hidden_sizes:
  779. self.convs.append(nn.Conv2d(channel, config.fusion_hidden_size, kernel_size=3, padding=1, bias=False))
  780. # fusion
  781. self.fusion_stage = DPTFeatureFusionStage(config)
  782. def forward(self, hidden_states: List[torch.Tensor], patch_height=None, patch_width=None) -> List[torch.Tensor]:
  783. """
  784. Args:
  785. hidden_states (`List[torch.FloatTensor]`, each of shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, hidden_size, height, width)`):
  786. List of hidden states from the backbone.
  787. """
  788. if not isinstance(hidden_states, (tuple, list)):
  789. raise TypeError("hidden_states should be a tuple or list of tensors")
  790. if len(hidden_states) != len(self.config.neck_hidden_sizes):
  791. raise ValueError("The number of hidden states should be equal to the number of neck hidden sizes.")
  792. # postprocess hidden states
  793. if self.reassemble_stage is not None:
  794. hidden_states = self.reassemble_stage(hidden_states, patch_height, patch_width)
  795. features = [self.convs[i](feature) for i, feature in enumerate(hidden_states)]
  796. # fusion blocks
  797. output = self.fusion_stage(features)
  798. return output
  799. class DPTDepthEstimationHead(nn.Module):
  800. """
  801. Output head consisting of 3 convolutional layers. It progressively halves the feature dimension and upsamples
  802. the predictions to the input resolution after the first convolutional layer (details can be found in the paper's
  803. supplementary material).
  804. """
  805. def __init__(self, config):
  806. super().__init__()
  807. self.config = config
  808. self.projection = None
  809. if config.add_projection:
  810. self.projection = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  811. features = config.fusion_hidden_size
  812. self.head = nn.Sequential(
  813. nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
  814. nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
  815. nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
  816. nn.ReLU(),
  817. nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
  818. nn.ReLU(),
  819. )
  820. def forward(self, hidden_states: List[torch.Tensor]) -> torch.Tensor:
  821. # use last features
  822. hidden_states = hidden_states[self.config.head_in_index]
  823. if self.projection is not None:
  824. hidden_states = self.projection(hidden_states)
  825. hidden_states = nn.ReLU()(hidden_states)
  826. predicted_depth = self.head(hidden_states)
  827. predicted_depth = predicted_depth.squeeze(dim=1)
  828. return predicted_depth
  829. @add_start_docstrings(
  830. """
  831. DPT Model with a depth estimation head on top (consisting of 3 convolutional layers) e.g. for KITTI, NYUv2.
  832. """,
  833. DPT_START_DOCSTRING,
  834. )
  835. class DPTForDepthEstimation(DPTPreTrainedModel):
  836. def __init__(self, config):
  837. super().__init__(config)
  838. self.backbone = None
  839. if config.is_hybrid is False and (config.backbone_config is not None or config.backbone is not None):
  840. self.backbone = load_backbone(config)
  841. else:
  842. self.dpt = DPTModel(config, add_pooling_layer=False)
  843. # Neck
  844. self.neck = DPTNeck(config)
  845. # Depth estimation head
  846. self.head = DPTDepthEstimationHead(config)
  847. # Initialize weights and apply final processing
  848. self.post_init()
  849. @add_start_docstrings_to_model_forward(DPT_INPUTS_DOCSTRING)
  850. @replace_return_docstrings(output_type=DepthEstimatorOutput, config_class=_CONFIG_FOR_DOC)
  851. def forward(
  852. self,
  853. pixel_values: torch.FloatTensor,
  854. head_mask: Optional[torch.FloatTensor] = None,
  855. labels: Optional[torch.LongTensor] = None,
  856. output_attentions: Optional[bool] = None,
  857. output_hidden_states: Optional[bool] = None,
  858. return_dict: Optional[bool] = None,
  859. ) -> Union[Tuple[torch.Tensor], DepthEstimatorOutput]:
  860. r"""
  861. labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
  862. Ground truth depth estimation maps for computing the loss.
  863. Returns:
  864. Examples:
  865. ```python
  866. >>> from transformers import AutoImageProcessor, DPTForDepthEstimation
  867. >>> import torch
  868. >>> import numpy as np
  869. >>> from PIL import Image
  870. >>> import requests
  871. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  872. >>> image = Image.open(requests.get(url, stream=True).raw)
  873. >>> image_processor = AutoImageProcessor.from_pretrained("Intel/dpt-large")
  874. >>> model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
  875. >>> # prepare image for the model
  876. >>> inputs = image_processor(images=image, return_tensors="pt")
  877. >>> with torch.no_grad():
  878. ... outputs = model(**inputs)
  879. >>> # interpolate to original size
  880. >>> post_processed_output = image_processor.post_process_depth_estimation(
  881. ... outputs,
  882. ... target_sizes=[(image.height, image.width)],
  883. ... )
  884. >>> # visualize the prediction
  885. >>> predicted_depth = post_processed_output[0]["predicted_depth"]
  886. >>> depth = predicted_depth * 255 / predicted_depth.max()
  887. >>> depth = depth.detach().cpu().numpy()
  888. >>> depth = Image.fromarray(depth.astype("uint8"))
  889. ```"""
  890. loss = None
  891. if labels is not None:
  892. raise NotImplementedError("Training is not implemented yet")
  893. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  894. output_hidden_states = (
  895. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  896. )
  897. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  898. if self.backbone is not None:
  899. outputs = self.backbone.forward_with_filtered_kwargs(
  900. pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions
  901. )
  902. hidden_states = outputs.feature_maps
  903. else:
  904. outputs = self.dpt(
  905. pixel_values,
  906. head_mask=head_mask,
  907. output_attentions=output_attentions,
  908. output_hidden_states=True, # we need the intermediate hidden states
  909. return_dict=return_dict,
  910. )
  911. hidden_states = outputs.hidden_states if return_dict else outputs[1]
  912. # only keep certain features based on config.backbone_out_indices
  913. # note that the hidden_states also include the initial embeddings
  914. if not self.config.is_hybrid:
  915. hidden_states = [
  916. feature for idx, feature in enumerate(hidden_states[1:]) if idx in self.config.backbone_out_indices
  917. ]
  918. else:
  919. backbone_hidden_states = outputs.intermediate_activations if return_dict else list(outputs[-1])
  920. backbone_hidden_states.extend(
  921. feature
  922. for idx, feature in enumerate(hidden_states[1:])
  923. if idx in self.config.backbone_out_indices[2:]
  924. )
  925. hidden_states = backbone_hidden_states
  926. patch_height, patch_width = None, None
  927. if self.config.backbone_config is not None and self.config.is_hybrid is False:
  928. _, _, height, width = pixel_values.shape
  929. patch_size = self.config.backbone_config.patch_size
  930. patch_height = height // patch_size
  931. patch_width = width // patch_size
  932. hidden_states = self.neck(hidden_states, patch_height, patch_width)
  933. predicted_depth = self.head(hidden_states)
  934. if not return_dict:
  935. if output_hidden_states:
  936. output = (predicted_depth,) + outputs[1:]
  937. else:
  938. output = (predicted_depth,) + outputs[2:]
  939. return ((loss,) + output) if loss is not None else output
  940. return DepthEstimatorOutput(
  941. loss=loss,
  942. predicted_depth=predicted_depth,
  943. hidden_states=outputs.hidden_states if output_hidden_states else None,
  944. attentions=outputs.attentions,
  945. )
  946. class DPTSemanticSegmentationHead(nn.Module):
  947. def __init__(self, config):
  948. super().__init__()
  949. self.config = config
  950. features = config.fusion_hidden_size
  951. self.head = nn.Sequential(
  952. nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False),
  953. nn.BatchNorm2d(features),
  954. nn.ReLU(),
  955. nn.Dropout(config.semantic_classifier_dropout),
  956. nn.Conv2d(features, config.num_labels, kernel_size=1),
  957. nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
  958. )
  959. def forward(self, hidden_states: List[torch.Tensor]) -> torch.Tensor:
  960. # use last features
  961. hidden_states = hidden_states[self.config.head_in_index]
  962. logits = self.head(hidden_states)
  963. return logits
  964. class DPTAuxiliaryHead(nn.Module):
  965. def __init__(self, config):
  966. super().__init__()
  967. features = config.fusion_hidden_size
  968. self.head = nn.Sequential(
  969. nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False),
  970. nn.BatchNorm2d(features),
  971. nn.ReLU(),
  972. nn.Dropout(0.1, False),
  973. nn.Conv2d(features, config.num_labels, kernel_size=1),
  974. )
  975. def forward(self, hidden_states):
  976. logits = self.head(hidden_states)
  977. return logits
  978. @add_start_docstrings(
  979. """
  980. DPT Model with a semantic segmentation head on top e.g. for ADE20k, CityScapes.
  981. """,
  982. DPT_START_DOCSTRING,
  983. )
  984. class DPTForSemanticSegmentation(DPTPreTrainedModel):
  985. def __init__(self, config):
  986. super().__init__(config)
  987. self.dpt = DPTModel(config, add_pooling_layer=False)
  988. # Neck
  989. self.neck = DPTNeck(config)
  990. # Segmentation head(s)
  991. self.head = DPTSemanticSegmentationHead(config)
  992. self.auxiliary_head = DPTAuxiliaryHead(config) if config.use_auxiliary_head else None
  993. # Initialize weights and apply final processing
  994. self.post_init()
  995. @add_start_docstrings_to_model_forward(DPT_INPUTS_DOCSTRING)
  996. @replace_return_docstrings(output_type=SemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC)
  997. def forward(
  998. self,
  999. pixel_values: Optional[torch.FloatTensor] = None,
  1000. head_mask: Optional[torch.FloatTensor] = None,
  1001. labels: Optional[torch.LongTensor] = None,
  1002. output_attentions: Optional[bool] = None,
  1003. output_hidden_states: Optional[bool] = None,
  1004. return_dict: Optional[bool] = None,
  1005. ) -> Union[Tuple[torch.Tensor], SemanticSegmenterOutput]:
  1006. r"""
  1007. labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
  1008. Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
  1009. config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
  1010. Returns:
  1011. Examples:
  1012. ```python
  1013. >>> from transformers import AutoImageProcessor, DPTForSemanticSegmentation
  1014. >>> from PIL import Image
  1015. >>> import requests
  1016. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  1017. >>> image = Image.open(requests.get(url, stream=True).raw)
  1018. >>> image_processor = AutoImageProcessor.from_pretrained("Intel/dpt-large-ade")
  1019. >>> model = DPTForSemanticSegmentation.from_pretrained("Intel/dpt-large-ade")
  1020. >>> inputs = image_processor(images=image, return_tensors="pt")
  1021. >>> outputs = model(**inputs)
  1022. >>> logits = outputs.logits
  1023. ```"""
  1024. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1025. output_hidden_states = (
  1026. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1027. )
  1028. if labels is not None and self.config.num_labels == 1:
  1029. raise ValueError("The number of labels should be greater than one")
  1030. outputs = self.dpt(
  1031. pixel_values,
  1032. head_mask=head_mask,
  1033. output_attentions=output_attentions,
  1034. output_hidden_states=True, # we need the intermediate hidden states
  1035. return_dict=return_dict,
  1036. )
  1037. hidden_states = outputs.hidden_states if return_dict else outputs[1]
  1038. # only keep certain features based on config.backbone_out_indices
  1039. # note that the hidden_states also include the initial embeddings
  1040. if not self.config.is_hybrid:
  1041. hidden_states = [
  1042. feature for idx, feature in enumerate(hidden_states[1:]) if idx in self.config.backbone_out_indices
  1043. ]
  1044. else:
  1045. backbone_hidden_states = outputs.intermediate_activations if return_dict else list(outputs[-1])
  1046. backbone_hidden_states.extend(
  1047. feature for idx, feature in enumerate(hidden_states[1:]) if idx in self.config.backbone_out_indices[2:]
  1048. )
  1049. hidden_states = backbone_hidden_states
  1050. hidden_states = self.neck(hidden_states=hidden_states)
  1051. logits = self.head(hidden_states)
  1052. auxiliary_logits = None
  1053. if self.auxiliary_head is not None:
  1054. auxiliary_logits = self.auxiliary_head(hidden_states[-1])
  1055. loss = None
  1056. if labels is not None:
  1057. # upsample logits to the images' original size
  1058. upsampled_logits = nn.functional.interpolate(
  1059. logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
  1060. )
  1061. if auxiliary_logits is not None:
  1062. upsampled_auxiliary_logits = nn.functional.interpolate(
  1063. auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
  1064. )
  1065. # compute weighted loss
  1066. loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
  1067. main_loss = loss_fct(upsampled_logits, labels)
  1068. auxiliary_loss = loss_fct(upsampled_auxiliary_logits, labels)
  1069. loss = main_loss + self.config.auxiliary_loss_weight * auxiliary_loss
  1070. if not return_dict:
  1071. if output_hidden_states:
  1072. output = (logits,) + outputs[1:]
  1073. else:
  1074. output = (logits,) + outputs[2:]
  1075. return ((loss,) + output) if loss is not None else output
  1076. return SemanticSegmenterOutput(
  1077. loss=loss,
  1078. logits=logits,
  1079. hidden_states=outputs.hidden_states if output_hidden_states else None,
  1080. attentions=outputs.attentions,
  1081. )