modeling_glpn.py

# coding=utf-8
# Copyright 2022 KAIST and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch GLPN model."""

import math
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_glpn import GLPNConfig


logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "GLPNConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "vinvino02/glpn-kitti"
_EXPECTED_OUTPUT_SHAPE = [1, 512, 15, 20]


# Copied from transformers.models.beit.modeling_beit.drop_path
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    """
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize
    output = input.div(keep_prob) * random_tensor
    return output
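
# Worked example of the scaling above (illustrative): with drop_prob=0.1 during training, each sample's
# residual branch is zeroed with probability 0.1 and otherwise multiplied by 1/0.9 ~= 1.111, so the expected
# output matches the unscaled branch and no rescaling is needed at inference time (training=False returns the
# input unchanged).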


# Copied from transformers.models.segformer.modeling_segformer.SegformerDropPath
class GLPNDropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return "p={}".format(self.drop_prob)


# Copied from transformers.models.segformer.modeling_segformer.SegformerOverlapPatchEmbeddings
class GLPNOverlapPatchEmbeddings(nn.Module):
    """Construct the overlapping patch embeddings."""

    def __init__(self, patch_size, stride, num_channels, hidden_size):
        super().__init__()
        self.proj = nn.Conv2d(
            num_channels,
            hidden_size,
            kernel_size=patch_size,
            stride=stride,
            padding=patch_size // 2,
        )
        self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, pixel_values):
        embeddings = self.proj(pixel_values)
        _, _, height, width = embeddings.shape
        # (batch_size, num_channels, height, width) -> (batch_size, num_channels, height*width) -> (batch_size, height*width, num_channels)
        # this can be fed to a Transformer layer
        embeddings = embeddings.flatten(2).transpose(1, 2)
        embeddings = self.layer_norm(embeddings)
        return embeddings, height, width
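
# Shape sketch (illustrative): because kernel_size=patch_size, stride=stride and padding=patch_size//2, the
# projection downsamples the spatial dimensions roughly by `stride`. E.g. with patch_size=7 and stride=4 (a
# typical first-stage setting), a 480x640 image yields a 120x160 feature map, i.e. 120 * 160 = 19200 patch
# tokens fed to the first Transformer block.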


# Copied from transformers.models.segformer.modeling_segformer.SegformerEfficientSelfAttention
class GLPNEfficientSelfAttention(nn.Module):
    """SegFormer's efficient self-attention mechanism. Employs the sequence reduction process introduced in the [PvT
    paper](https://arxiv.org/abs/2102.12122)."""

    def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads

        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
                f"heads ({self.num_attention_heads})"
            )

        self.attention_head_size = int(self.hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(self.hidden_size, self.all_head_size)
        self.key = nn.Linear(self.hidden_size, self.all_head_size)
        self.value = nn.Linear(self.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

        self.sr_ratio = sequence_reduction_ratio
        if sequence_reduction_ratio > 1:
            self.sr = nn.Conv2d(
                hidden_size, hidden_size, kernel_size=sequence_reduction_ratio, stride=sequence_reduction_ratio
            )
            self.layer_norm = nn.LayerNorm(hidden_size)

    def transpose_for_scores(self, hidden_states):
        new_shape = hidden_states.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        hidden_states = hidden_states.view(new_shape)
        return hidden_states.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        height,
        width,
        output_attentions=False,
    ):
        query_layer = self.transpose_for_scores(self.query(hidden_states))

        if self.sr_ratio > 1:
            batch_size, seq_len, num_channels = hidden_states.shape
            # Reshape to (batch_size, num_channels, height, width)
            hidden_states = hidden_states.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)
            # Apply sequence reduction
            hidden_states = self.sr(hidden_states)
            # Reshape back to (batch_size, seq_len, num_channels)
            hidden_states = hidden_states.reshape(batch_size, num_channels, -1).permute(0, 2, 1)
            hidden_states = self.layer_norm(hidden_states)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
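
# Complexity sketch (illustrative): for a sequence of N = height * width tokens, the strided `sr` convolution
# shortens the key/value sequence to roughly N / sr_ratio**2, so each head's attention matrix has shape
# (N, N / sr_ratio**2) instead of (N, N). With sr_ratio=8 (a typical first-stage value) this cuts the attention
# cost by a factor of ~64; queries keep full resolution, so the output sequence length is unchanged.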


# Copied from transformers.models.segformer.modeling_segformer.SegformerSelfOutput
class GLPNSelfOutput(nn.Module):
    def __init__(self, config, hidden_size):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states


# Copied from transformers.models.segformer.modeling_segformer.SegformerAttention with Segformer->GLPN
class GLPNAttention(nn.Module):
    def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio):
        super().__init__()
        self.self = GLPNEfficientSelfAttention(
            config=config,
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            sequence_reduction_ratio=sequence_reduction_ratio,
        )
        self.output = GLPNSelfOutput(config, hidden_size=hidden_size)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, hidden_states, height, width, output_attentions=False):
        self_outputs = self.self(hidden_states, height, width, output_attentions)

        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


# Copied from transformers.models.segformer.modeling_segformer.SegformerDWConv
class GLPNDWConv(nn.Module):
    def __init__(self, dim=768):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, hidden_states, height, width):
        batch_size, seq_len, num_channels = hidden_states.shape
        hidden_states = hidden_states.transpose(1, 2).view(batch_size, num_channels, height, width)
        hidden_states = self.dwconv(hidden_states)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)
        return hidden_states


# Copied from transformers.models.segformer.modeling_segformer.SegformerMixFFN with Segformer->GLPN
class GLPNMixFFN(nn.Module):
    def __init__(self, config, in_features, hidden_features=None, out_features=None):
        super().__init__()
        out_features = out_features or in_features
        self.dense1 = nn.Linear(in_features, hidden_features)
        self.dwconv = GLPNDWConv(hidden_features)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act
        self.dense2 = nn.Linear(hidden_features, out_features)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, height, width):
        hidden_states = self.dense1(hidden_states)
        hidden_states = self.dwconv(hidden_states, height, width)
        hidden_states = self.intermediate_act_fn(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.dense2(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states


# Copied from transformers.models.segformer.modeling_segformer.SegformerLayer with Segformer->GLPN
class GLPNLayer(nn.Module):
    """This corresponds to the Block class in the original implementation."""

    def __init__(self, config, hidden_size, num_attention_heads, drop_path, sequence_reduction_ratio, mlp_ratio):
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(hidden_size)
        self.attention = GLPNAttention(
            config,
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            sequence_reduction_ratio=sequence_reduction_ratio,
        )
        self.drop_path = GLPNDropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.layer_norm_2 = nn.LayerNorm(hidden_size)
        mlp_hidden_size = int(hidden_size * mlp_ratio)
        self.mlp = GLPNMixFFN(config, in_features=hidden_size, hidden_features=mlp_hidden_size)

    def forward(self, hidden_states, height, width, output_attentions=False):
        self_attention_outputs = self.attention(
            self.layer_norm_1(hidden_states),  # in GLPN, layernorm is applied before self-attention
            height,
            width,
            output_attentions=output_attentions,
        )

        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # first residual connection (with stochastic depth)
        attention_output = self.drop_path(attention_output)
        hidden_states = attention_output + hidden_states

        mlp_output = self.mlp(self.layer_norm_2(hidden_states), height, width)

        # second residual connection (with stochastic depth)
        mlp_output = self.drop_path(mlp_output)
        layer_output = mlp_output + hidden_states

        outputs = (layer_output,) + outputs

        return outputs


class GLPNEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]

        # patch embeddings
        embeddings = []
        for i in range(config.num_encoder_blocks):
            embeddings.append(
                GLPNOverlapPatchEmbeddings(
                    patch_size=config.patch_sizes[i],
                    stride=config.strides[i],
                    num_channels=config.num_channels if i == 0 else config.hidden_sizes[i - 1],
                    hidden_size=config.hidden_sizes[i],
                )
            )
        self.patch_embeddings = nn.ModuleList(embeddings)

        # Transformer blocks
        blocks = []
        cur = 0
        for i in range(config.num_encoder_blocks):
            # each block consists of layers
            layers = []
            if i != 0:
                cur += config.depths[i - 1]
            for j in range(config.depths[i]):
                layers.append(
                    GLPNLayer(
                        config,
                        hidden_size=config.hidden_sizes[i],
                        num_attention_heads=config.num_attention_heads[i],
                        drop_path=dpr[cur + j],
                        sequence_reduction_ratio=config.sr_ratios[i],
                        mlp_ratio=config.mlp_ratios[i],
                    )
                )
            blocks.append(nn.ModuleList(layers))

        self.block = nn.ModuleList(blocks)

        # Layer norms
        self.layer_norm = nn.ModuleList(
            [nn.LayerNorm(config.hidden_sizes[i]) for i in range(config.num_encoder_blocks)]
        )

    def forward(
        self,
        pixel_values,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        batch_size = pixel_values.shape[0]

        hidden_states = pixel_values
        for idx, x in enumerate(zip(self.patch_embeddings, self.block, self.layer_norm)):
            embedding_layer, block_layer, norm_layer = x
            # first, obtain patch embeddings
            hidden_states, height, width = embedding_layer(hidden_states)
            # second, send embeddings through blocks
            for i, blk in enumerate(block_layer):
                layer_outputs = blk(hidden_states, height, width, output_attentions)
                hidden_states = layer_outputs[0]
                if output_attentions:
                    all_self_attentions = all_self_attentions + (layer_outputs[1],)
            # third, apply layer norm
            hidden_states = norm_layer(hidden_states)
            # fourth, optionally reshape back to (batch_size, num_channels, height, width)
            hidden_states = hidden_states.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2).contiguous()
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
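
# Output sketch (illustrative): with the usual four encoder blocks and strides (4, 2, 2, 2), the hidden states
# collected above are feature maps at 1/4, 1/8, 1/16 and 1/32 of the input resolution, each with
# config.hidden_sizes[i] channels. For example, `_EXPECTED_OUTPUT_SHAPE` [1, 512, 15, 20] is the last-stage map
# of the glpn-kitti checkpoint for a 480x640 input (480 / 32 = 15, 640 / 32 = 20).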


class GLPNPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = GLPNConfig
    base_model_prefix = "glpn"
    main_input_name = "pixel_values"
    _no_split_modules = []

    # Copied from transformers.models.segformer.modeling_segformer.SegformerPreTrainedModel._init_weights
    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


GLPN_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`GLPNConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

GLPN_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`GLPNImageProcessor.__call__`] for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare GLPN encoder (Mix-Transformer) outputting raw hidden-states without any specific head on top.",
    GLPN_START_DOCSTRING,
)
class GLPNModel(GLPNPreTrainedModel):
    # Copied from transformers.models.segformer.modeling_segformer.SegformerModel.__init__ with Segformer->GLPN
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # hierarchical Transformer encoder
        self.encoder = GLPNEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(GLPN_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    # Copied from transformers.models.segformer.modeling_segformer.SegformerModel.forward
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_outputs = self.encoder(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]

        if not return_dict:
            return (sequence_output,) + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class GLPNSelectiveFeatureFusion(nn.Module):
    """
    Selective Feature Fusion module, as explained in the [paper](https://arxiv.org/abs/2201.07436) (section 3.4). This
    module adaptively selects and integrates local and global features by attaining an attention map for each feature.
    """

    def __init__(self, in_channel=64):
        super().__init__()

        self.convolutional_layer1 = nn.Sequential(
            nn.Conv2d(in_channels=int(in_channel * 2), out_channels=in_channel, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(in_channel),
            nn.ReLU(),
        )

        self.convolutional_layer2 = nn.Sequential(
            nn.Conv2d(in_channels=in_channel, out_channels=int(in_channel / 2), kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(int(in_channel / 2)),
            nn.ReLU(),
        )

        self.convolutional_layer3 = nn.Conv2d(
            in_channels=int(in_channel / 2), out_channels=2, kernel_size=3, stride=1, padding=1
        )

        self.sigmoid = nn.Sigmoid()

    def forward(self, local_features, global_features):
        # concatenate features along the channel dimension
        features = torch.cat((local_features, global_features), dim=1)
        # pass through convolutional layers
        features = self.convolutional_layer1(features)
        features = self.convolutional_layer2(features)
        features = self.convolutional_layer3(features)
        # apply sigmoid to get two-channel attention map
        attn = self.sigmoid(features)
        # construct hybrid features by adding element-wise
        hybrid_features = local_features * attn[:, 0, :, :].unsqueeze(1) + global_features * attn[
            :, 1, :, :
        ].unsqueeze(1)

        return hybrid_features
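
# Fusion sketch (illustrative): writing the two sigmoid channels as alpha = attn[:, 0] and beta = attn[:, 1]
# (independent sigmoids, not constrained to sum to 1), the result is
# hybrid = local_features * alpha + global_features * beta, a per-pixel weighting of the two feature streams
# predicted from their concatenation.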


class GLPNDecoderStage(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        should_skip = in_channels == out_channels
        self.convolution = nn.Conv2d(in_channels, out_channels, kernel_size=1) if not should_skip else nn.Identity()
        self.fusion = GLPNSelectiveFeatureFusion(out_channels)
        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)

    def forward(self, hidden_state, residual=None):
        hidden_state = self.convolution(hidden_state)
        if residual is not None:
            hidden_state = self.fusion(hidden_state, residual)
        hidden_state = self.upsample(hidden_state)

        return hidden_state


class GLPNDecoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        # we use features from end -> start
        reversed_hidden_sizes = config.hidden_sizes[::-1]
        out_channels = config.decoder_hidden_size

        self.stages = nn.ModuleList(
            [GLPNDecoderStage(hidden_size, out_channels) for hidden_size in reversed_hidden_sizes]
        )
        # don't fuse in first stage
        self.stages[0].fusion = None

        self.final_upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)

    def forward(self, hidden_states: List[torch.Tensor]) -> List[torch.Tensor]:
        stage_hidden_states = []
        stage_hidden_state = None
        for hidden_state, stage in zip(hidden_states[::-1], self.stages):
            stage_hidden_state = stage(hidden_state, stage_hidden_state)
            stage_hidden_states.append(stage_hidden_state)

        stage_hidden_states[-1] = self.final_upsample(stage_hidden_state)

        return stage_hidden_states
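
# Decoder sketch (illustrative): encoder features are consumed deepest-first; each stage projects its input to
# config.decoder_hidden_size channels, fuses it with the previous stage's output (skipped for the first stage),
# and upsamples by 2. With four stages starting from the 1/32-resolution map, the last stage emits a
# 1/2-resolution map and `final_upsample` brings it back to the (resized) input resolution.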


class SiLogLoss(nn.Module):
    r"""
    Implements the Scale-invariant log scale loss [Eigen et al., 2014](https://arxiv.org/abs/1406.2283).

    $$L=\frac{1}{n} \sum_{i} d_{i}^{2}-\frac{1}{2 n^{2}}\left(\sum_{i} d_{i}\right)^{2}$$ where $d_{i}=\log y_{i}-\log
    y_{i}^{*}$.
    """

    def __init__(self, lambd=0.5):
        super().__init__()
        self.lambd = lambd

    def forward(self, pred, target):
        valid_mask = (target > 0).detach()
        diff_log = torch.log(target[valid_mask]) - torch.log(pred[valid_mask])
        loss = torch.sqrt(torch.pow(diff_log, 2).mean() - self.lambd * torch.pow(diff_log.mean(), 2))

        return loss
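
# Relation to the docstring formula (illustrative): with d = log(target) - log(pred) over the valid (target > 0)
# pixels, the code returns sqrt(mean(d**2) - lambd * mean(d)**2); since mean(d)**2 = (1/n**2) * (sum d)**2, the
# expression under the square root equals the loss above for lambd = 0.5.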


class GLPNDepthEstimationHead(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.config = config

        channels = config.decoder_hidden_size
        self.head = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=False),
            nn.Conv2d(channels, 1, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, hidden_states: List[torch.Tensor]) -> torch.Tensor:
        # use last features of the decoder
        hidden_states = hidden_states[self.config.head_in_index]

        hidden_states = self.head(hidden_states)

        predicted_depth = torch.sigmoid(hidden_states) * self.config.max_depth
        predicted_depth = predicted_depth.squeeze(dim=1)

        return predicted_depth


@add_start_docstrings(
    """GLPN Model transformer with a lightweight depth estimation head on top e.g. for KITTI, NYUv2.""",
    GLPN_START_DOCSTRING,
)
class GLPNForDepthEstimation(GLPNPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.glpn = GLPNModel(config)
        self.decoder = GLPNDecoder(config)
        self.head = GLPNDepthEstimationHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(GLPN_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=DepthEstimatorOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], DepthEstimatorOutput]:
        r"""
        labels (`torch.FloatTensor` of shape `(batch_size, height, width)`, *optional*):
            Ground truth depth estimation maps for computing the loss.

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, GLPNForDepthEstimation
        >>> import torch
        >>> import numpy as np
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("vinvino02/glpn-kitti")
        >>> model = GLPNForDepthEstimation.from_pretrained("vinvino02/glpn-kitti")

        >>> # prepare image for the model
        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)
        ...     predicted_depth = outputs.predicted_depth

        >>> # interpolate to original size
        >>> prediction = torch.nn.functional.interpolate(
        ...     predicted_depth.unsqueeze(1),
        ...     size=image.size[::-1],
        ...     mode="bicubic",
        ...     align_corners=False,
        ... )

        >>> # visualize the prediction
        >>> output = prediction.squeeze().cpu().numpy()
        >>> formatted = (output * 255 / np.max(output)).astype("uint8")
        >>> depth = Image.fromarray(formatted)
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        outputs = self.glpn(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=True,  # we need the intermediate hidden states
            return_dict=return_dict,
        )

        hidden_states = outputs.hidden_states if return_dict else outputs[1]

        out = self.decoder(hidden_states)
        predicted_depth = self.head(out)

        loss = None
        if labels is not None:
            loss_fct = SiLogLoss()
            loss = loss_fct(predicted_depth, labels)

        if not return_dict:
            if output_hidden_states:
                output = (predicted_depth,) + outputs[1:]
            else:
                output = (predicted_depth,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return DepthEstimatorOutput(
            loss=loss,
            predicted_depth=predicted_depth,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )