# coding=utf-8
# Copyright 2023 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch FocalNet model."""

import collections.abc
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_outputs import BackboneOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from ...utils.backbone_utils import BackboneMixin
from .configuration_focalnet import FocalNetConfig


logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "FocalNetConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "microsoft/focalnet-tiny"
_EXPECTED_OUTPUT_SHAPE = [1, 49, 768]

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "microsoft/focalnet-tiny"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
@dataclass
class FocalNetEncoderOutput(ModelOutput):
    """
    FocalNet encoder's outputs, with potential hidden states.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    """

    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class FocalNetModelOutput(ModelOutput):
    """
    FocalNet model's outputs that also contains a pooling of the last hidden states.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
            Average pooling of the last layer hidden-state.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    """

    last_hidden_state: torch.FloatTensor = None
    pooler_output: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class FocalNetMaskedImageModelingOutput(ModelOutput):
    """
    FocalNet masked image model outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
            Masked image modeling (MLM) loss.
        reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Reconstructed pixel values.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    """

    loss: Optional[torch.FloatTensor] = None
    reconstruction: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class FocalNetImageClassifierOutput(ModelOutput):
    """
    FocalNet outputs for image classification.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
class FocalNetEmbeddings(nn.Module):
    """
    Construct the patch embeddings and layernorm. Optionally, also the mask token.
    """

    def __init__(self, config, use_mask_token=False):
        super().__init__()

        self.patch_embeddings = FocalNetPatchEmbeddings(
            config=config,
            image_size=config.image_size,
            patch_size=config.patch_size,
            num_channels=config.num_channels,
            embed_dim=config.embed_dim,
            use_conv_embed=config.use_conv_embed,
            is_stem=True,
        )
        self.patch_grid = self.patch_embeddings.grid_size
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None

        self.norm = nn.LayerNorm(config.embed_dim, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None
    ) -> Tuple[torch.Tensor]:
        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
        embeddings = self.norm(embeddings)
        batch_size, seq_len, _ = embeddings.size()

        if bool_masked_pos is not None:
            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
            # replace the masked visual tokens by mask_tokens
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        embeddings = self.dropout(embeddings)
        return embeddings, output_dimensions
class FocalNetPatchEmbeddings(nn.Module):
    def __init__(
        self,
        config,
        image_size,
        patch_size,
        num_channels,
        embed_dim,
        add_norm=False,
        use_conv_embed=False,
        is_stem=False,
    ):
        super().__init__()
        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches
        self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])

        if use_conv_embed:
            # if we choose to use conv embedding, then we treat the stem and non-stem differently
            if is_stem:
                kernel_size = 7
                padding = 2
                stride = 4
            else:
                kernel_size = 3
                padding = 1
                stride = 2
            self.projection = nn.Conv2d(
                num_channels, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
            )
        else:
            self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

        if add_norm:
            self.norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        else:
            self.norm = None

    def maybe_pad(self, pixel_values, height, width):
        if width % self.patch_size[1] != 0:
            pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
            pixel_values = nn.functional.pad(pixel_values, pad_values)
        if height % self.patch_size[0] != 0:
            pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
            pixel_values = nn.functional.pad(pixel_values, pad_values)
        return pixel_values

    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
        _, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        # pad the input to be divisible by self.patch_size, if needed
        pixel_values = self.maybe_pad(pixel_values, height, width)
        embeddings = self.projection(pixel_values)
        _, _, height, width = embeddings.shape
        output_dimensions = (height, width)
        embeddings = embeddings.flatten(2).transpose(1, 2)

        if self.norm is not None:
            embeddings = self.norm(embeddings)

        return embeddings, output_dimensions
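

# A minimal shape sketch for FocalNetPatchEmbeddings (illustrative only; the numbers
# below assume the default FocalNetConfig with image_size=224, patch_size=4 and
# embed_dim=96, not a particular checkpoint):
#
#     config = FocalNetConfig()
#     patch_embed = FocalNetPatchEmbeddings(
#         config=config,
#         image_size=config.image_size,
#         patch_size=config.patch_size,
#         num_channels=config.num_channels,
#         embed_dim=config.embed_dim,
#         is_stem=True,
#     )
#     pixels = torch.randn(1, 3, 224, 224)
#     embeddings, (height, width) = patch_embed(pixels)
#     # embeddings.shape == (1, 3136, 96) since 224 / 4 = 56 and 56 * 56 = 3136
#     # (height, width) == (56, 56)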
# Copied from transformers.models.beit.modeling_beit.drop_path
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    """
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize
    output = input.div(keep_prob) * random_tensor
    return output
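

# Illustrative sketch (not part of the model): per-sample stochastic depth keeps each
# sample's residual branch with probability 1 - drop_prob and rescales survivors by
# 1 / keep_prob so the expected value is unchanged. The shapes below are arbitrary:
#
#     hidden = torch.ones(4, 49, 96)                        # (batch, seq_len, channels)
#     out = drop_path(hidden, drop_prob=0.5, training=True)
#     # each of the 4 samples is either all zeros or all 2.0 (= 1 / keep_prob);
#     # with training=False (inference) the input is returned unchanged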
# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->FocalNet
class FocalNetDropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return "p={}".format(self.drop_prob)
class FocalNetModulation(nn.Module):
    def __init__(self, config, index, dim, focal_factor=2, bias=True, projection_dropout=0.0):
        super().__init__()

        self.dim = dim
        self.focal_window = config.focal_windows[index]
        self.focal_level = config.focal_levels[index]
        self.focal_factor = focal_factor
        self.use_post_layernorm_in_modulation = config.use_post_layernorm_in_modulation
        self.normalize_modulator = config.normalize_modulator

        self.projection_in = nn.Linear(dim, 2 * dim + (self.focal_level + 1), bias=bias)
        self.projection_context = nn.Conv2d(dim, dim, kernel_size=1, stride=1, bias=bias)

        self.activation = nn.GELU()
        self.projection_out = nn.Linear(dim, dim)
        self.projection_dropout = nn.Dropout(projection_dropout)
        self.focal_layers = nn.ModuleList()

        self.kernel_sizes = []
        for k in range(self.focal_level):
            kernel_size = self.focal_factor * k + self.focal_window
            self.focal_layers.append(
                nn.Sequential(
                    nn.Conv2d(
                        dim, dim, kernel_size=kernel_size, stride=1, groups=dim, padding=kernel_size // 2, bias=False
                    ),
                    nn.GELU(),
                )
            )
            self.kernel_sizes.append(kernel_size)
        if self.use_post_layernorm_in_modulation:
            self.layernorm = nn.LayerNorm(dim, eps=config.layer_norm_eps)

    def forward(self, hidden_state):
        """
        Args:
            hidden_state:
                Input features with shape of (batch_size, height, width, num_channels)
        """
        num_channels = hidden_state.shape[-1]

        # pre linear projection
        x = self.projection_in(hidden_state).permute(0, 3, 1, 2).contiguous()
        q, ctx, self.gates = torch.split(x, (num_channels, num_channels, self.focal_level + 1), 1)

        # context aggregation
        ctx_all = 0
        for level in range(self.focal_level):
            ctx = self.focal_layers[level](ctx)
            ctx_all = ctx_all + ctx * self.gates[:, level : level + 1]
        ctx_global = self.activation(ctx.mean(2, keepdim=True).mean(3, keepdim=True))
        ctx_all = ctx_all + ctx_global * self.gates[:, self.focal_level :]

        # normalize context
        if self.normalize_modulator:
            ctx_all = ctx_all / (self.focal_level + 1)

        # focal modulation
        self.modulator = self.projection_context(ctx_all)
        x_out = q * self.modulator
        x_out = x_out.permute(0, 2, 3, 1).contiguous()
        if self.use_post_layernorm_in_modulation:
            x_out = self.layernorm(x_out)

        # post linear projection
        x_out = self.projection_out(x_out)
        x_out = self.projection_dropout(x_out)
        return x_out
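

# A minimal shape sketch for FocalNetModulation (illustrative only; dim=96 and the
# 56x56 grid below come from the default FocalNetConfig and are assumptions, not a
# particular checkpoint). Focal modulation is shape-preserving on its input:
#
#     config = FocalNetConfig()
#     modulation = FocalNetModulation(config, index=0, dim=96)
#     features = torch.randn(1, 56, 56, 96)      # (batch, height, width, channels)
#     modulated = modulation(features)
#     # modulated.shape == (1, 56, 56, 96)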
class FocalNetMlp(nn.Module):
    def __init__(self, config, in_features, hidden_features=None, out_features=None, drop=0.0):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.activation = ACT2FN[config.hidden_act]
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, hidden_state):
        hidden_state = self.fc1(hidden_state)
        hidden_state = self.activation(hidden_state)
        hidden_state = self.drop(hidden_state)
        hidden_state = self.fc2(hidden_state)
        hidden_state = self.drop(hidden_state)
        return hidden_state
class FocalNetLayer(nn.Module):
    r"""Focal Modulation Network layer (block).

    Args:
        config (`FocalNetConfig`):
            Model config.
        index (`int`):
            Layer index.
        dim (`int`):
            Number of input channels.
        input_resolution (`Tuple[int]`):
            Input resolution.
        drop_path (`float`, *optional*, defaults to 0.0):
            Stochastic depth rate.
    """

    def __init__(self, config, index, dim, input_resolution, drop_path=0.0):
        super().__init__()

        self.config = config

        # layer-specific attributes
        self.dim = dim
        self.input_resolution = input_resolution

        # general attributes
        self.drop = config.hidden_dropout_prob
        self.use_post_layernorm = config.use_post_layernorm

        self.norm1 = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.modulation = FocalNetModulation(
            config=config,
            index=index,
            dim=dim,
            projection_dropout=self.drop,
        )

        self.drop_path = FocalNetDropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        mlp_hidden_dim = int(dim * config.mlp_ratio)
        self.mlp = FocalNetMlp(config=config, in_features=dim, hidden_features=mlp_hidden_dim, drop=self.drop)

        self.gamma_1 = 1.0
        self.gamma_2 = 1.0
        if config.use_layerscale:
            self.gamma_1 = nn.Parameter(config.layerscale_value * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(config.layerscale_value * torch.ones((dim)), requires_grad=True)

    def forward(self, hidden_state, input_dimensions):
        height, width = input_dimensions
        batch_size, _, num_channels = hidden_state.shape
        shortcut = hidden_state

        # Focal Modulation
        hidden_state = hidden_state if self.use_post_layernorm else self.norm1(hidden_state)
        hidden_state = hidden_state.view(batch_size, height, width, num_channels)
        hidden_state = self.modulation(hidden_state).view(batch_size, height * width, num_channels)
        hidden_state = hidden_state if not self.use_post_layernorm else self.norm1(hidden_state)

        # FFN
        hidden_state = shortcut + self.drop_path(self.gamma_1 * hidden_state)
        hidden_state = hidden_state + self.drop_path(
            self.gamma_2
            * (self.norm2(self.mlp(hidden_state)) if self.use_post_layernorm else self.mlp(self.norm2(hidden_state)))
        )

        return hidden_state
class FocalNetStage(nn.Module):
    def __init__(self, config, index, input_resolution):
        super().__init__()

        self.config = config
        self.num_stages = len(config.depths)

        embed_dim = [config.embed_dim * (2**i) for i in range(self.num_stages)]
        dim = embed_dim[index]
        out_dim = embed_dim[index + 1] if (index < self.num_stages - 1) else None
        downsample = FocalNetPatchEmbeddings if (index < self.num_stages - 1) else None

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
        drop_path = dpr[sum(config.depths[:index]) : sum(config.depths[: index + 1])]

        self.layers = nn.ModuleList(
            [
                FocalNetLayer(
                    config=config,
                    index=index,
                    dim=dim,
                    input_resolution=input_resolution,
                    drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                )
                for i in range(config.depths[index])
            ]
        )

        if downsample is not None:
            self.downsample = downsample(
                config=config,
                image_size=input_resolution,
                patch_size=2,
                num_channels=dim,
                embed_dim=out_dim,
                add_norm=True,
                use_conv_embed=config.use_conv_embed,
                is_stem=False,
            )
        else:
            self.downsample = None

        self.pointing = False

    def forward(self, hidden_states: torch.Tensor, input_dimensions: Tuple[int, int]) -> Tuple[torch.Tensor]:
        height, width = input_dimensions
        for layer_module in self.layers:
            hidden_states = layer_module(hidden_states, input_dimensions)

        hidden_states_before_downsampling = hidden_states
        if self.downsample is not None:
            height, width = input_dimensions
            hidden_states = hidden_states.transpose(1, 2).reshape(
                hidden_states_before_downsampling.shape[0], -1, height, width
            )
            hidden_states, output_dimensions = self.downsample(hidden_states)
        else:
            output_dimensions = (height, width, height, width)

        stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions)

        return stage_outputs
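

# Illustrative note (assumptions: default FocalNetConfig and stage index 0): each stage
# keeps the token count fixed through its FocalNetLayers and then, except for the last
# stage, halves the spatial resolution and doubles the channel dimension via a patch
# embedding with patch_size=2. For a 56x56 grid with 96 channels:
#
#     config = FocalNetConfig()
#     stage = FocalNetStage(config, index=0, input_resolution=(56, 56))
#     tokens = torch.randn(1, 56 * 56, 96)
#     downsampled, before_downsampling, dims = stage(tokens, (56, 56))
#     # before_downsampling.shape == (1, 3136, 96)
#     # downsampled.shape == (1, 784, 192) and dims == (28, 28)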
class FocalNetEncoder(nn.Module):
    def __init__(self, config, grid_size):
        super().__init__()
        self.num_stages = len(config.depths)
        self.config = config

        self.stages = nn.ModuleList(
            [
                FocalNetStage(
                    config=config,
                    index=i_layer,
                    input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
                )
                for i_layer in range(self.num_stages)
            ]
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: Tuple[int, int],
        output_hidden_states: Optional[bool] = False,
        output_hidden_states_before_downsampling: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, FocalNetEncoderOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_reshaped_hidden_states = () if output_hidden_states else None

        if output_hidden_states:
            batch_size, _, hidden_size = hidden_states.shape
            # rearrange b (h w) c -> b c h w
            reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
            reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
            all_hidden_states += (hidden_states,)
            all_reshaped_hidden_states += (reshaped_hidden_state,)

        for i, stage_module in enumerate(self.stages):
            if self.gradient_checkpointing and self.training:
                stage_outputs = self._gradient_checkpointing_func(
                    stage_module.__call__,
                    hidden_states,
                    input_dimensions,
                )
            else:
                stage_outputs = stage_module(hidden_states, input_dimensions)

            hidden_states = stage_outputs[0]
            hidden_states_before_downsampling = stage_outputs[1]
            output_dimensions = stage_outputs[2]

            input_dimensions = (output_dimensions[-2], output_dimensions[-1])

            if output_hidden_states and output_hidden_states_before_downsampling:
                batch_size, _, hidden_size = hidden_states_before_downsampling.shape
                # rearrange b (h w) c -> b c h w
                # here we use the original (not downsampled) height and width
                reshaped_hidden_state = hidden_states_before_downsampling.view(
                    batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size
                )
                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
                all_hidden_states += (hidden_states_before_downsampling,)
                all_reshaped_hidden_states += (reshaped_hidden_state,)
            elif output_hidden_states and not output_hidden_states_before_downsampling:
                batch_size, _, hidden_size = hidden_states.shape
                # rearrange b (h w) c -> b c h w
                reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
                all_hidden_states += (hidden_states,)
                all_reshaped_hidden_states += (reshaped_hidden_state,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)

        return FocalNetEncoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            reshaped_hidden_states=all_reshaped_hidden_states,
        )
# Copied from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->FocalNet,swin->focalnet
class FocalNetPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = FocalNetConfig
    base_model_prefix = "focalnet"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _no_split_modules = ["FocalNetStage"]

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
FOCALNET_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`FocalNetConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

FOCALNET_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`AutoImageProcessor.__call__`] for details.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
    "The bare FocalNet Model outputting raw hidden-states without any specific head on top.",
    FOCALNET_START_DOCSTRING,
)
class FocalNetModel(FocalNetPreTrainedModel):
    def __init__(self, config, add_pooling_layer=True, use_mask_token=False):
        super().__init__(config)
        self.config = config
        self.num_stages = len(config.depths)
        self.num_features = int(config.embed_dim * 2 ** (self.num_stages - 1))

        self.embeddings = FocalNetEmbeddings(config, use_mask_token=use_mask_token)
        self.encoder = FocalNetEncoder(config, self.embeddings.patch_grid)

        self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
        self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    @add_start_docstrings_to_model_forward(FOCALNET_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=FocalNetModelOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, FocalNetModelOutput]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)

        encoder_outputs = self.encoder(
            embedding_output,
            input_dimensions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)

        pooled_output = None
        if self.pooler is not None:
            pooled_output = self.pooler(sequence_output.transpose(1, 2))
            pooled_output = torch.flatten(pooled_output, 1)

        if not return_dict:
            output = (sequence_output, pooled_output) + encoder_outputs[1:]
            return output

        return FocalNetModelOutput(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
        )
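

# A minimal usage sketch for FocalNetModel (illustrative; the checkpoint matches
# _CHECKPOINT_FOR_DOC above and the resulting shape matches _EXPECTED_OUTPUT_SHAPE):
#
#     from transformers import AutoImageProcessor, FocalNetModel
#     from PIL import Image
#     import requests
#
#     url = "http://images.cocodataset.org/val2017/000000039769.jpg"
#     image = Image.open(requests.get(url, stream=True).raw)
#     processor = AutoImageProcessor.from_pretrained("microsoft/focalnet-tiny")
#     model = FocalNetModel.from_pretrained("microsoft/focalnet-tiny")
#     inputs = processor(images=image, return_tensors="pt")
#     outputs = model(**inputs)
#     # outputs.last_hidden_state.shape == (1, 49, 768) for a 224x224 input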
@add_start_docstrings(
    """FocalNet Model with a decoder on top for masked image modeling.

    This follows the same implementation as in [SimMIM](https://arxiv.org/abs/2111.09886).

    <Tip>

    Note that we provide a script to pre-train this model on custom data in our [examples
    directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).

    </Tip>
    """,
    FOCALNET_START_DOCSTRING,
)
class FocalNetForMaskedImageModeling(FocalNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.focalnet = FocalNetModel(config, add_pooling_layer=False, use_mask_token=True)

        self.num_stages = len(config.depths)
        num_features = int(config.embed_dim * 2 ** (self.num_stages - 1))
        self.decoder = nn.Sequential(
            nn.Conv2d(
                in_channels=num_features, out_channels=config.encoder_stride**2 * config.num_channels, kernel_size=1
            ),
            nn.PixelShuffle(config.encoder_stride),
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(FOCALNET_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FocalNetMaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, FocalNetMaskedImageModelingOutput]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

        Returns:

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, FocalNetConfig, FocalNetForMaskedImageModeling
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/focalnet-base-simmim-window6-192")
        >>> config = FocalNetConfig()
        >>> model = FocalNetForMaskedImageModeling(config)

        >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
        >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
        >>> # create random boolean mask of shape (batch_size, num_patches)
        >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()

        >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
        >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
        >>> list(reconstructed_pixel_values.shape)
        [1, 3, 192, 192]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.focalnet(
            pixel_values,
            bool_masked_pos=bool_masked_pos,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        # Reshape to (batch_size, num_channels, height, width)
        sequence_output = sequence_output.transpose(1, 2)
        batch_size, num_channels, sequence_length = sequence_output.shape
        height = width = math.floor(sequence_length**0.5)
        sequence_output = sequence_output.reshape(batch_size, num_channels, height, width)

        # Reconstruct pixel values
        reconstructed_pixel_values = self.decoder(sequence_output)

        masked_im_loss = None
        if bool_masked_pos is not None:
            size = self.config.image_size // self.config.patch_size
            bool_masked_pos = bool_masked_pos.reshape(-1, size, size)
            mask = (
                bool_masked_pos.repeat_interleave(self.config.patch_size, 1)
                .repeat_interleave(self.config.patch_size, 2)
                .unsqueeze(1)
                .contiguous()
            )
            reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
            masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels

        if not return_dict:
            output = (reconstructed_pixel_values,) + outputs[2:]
            return ((masked_im_loss,) + output) if masked_im_loss is not None else output

        return FocalNetMaskedImageModelingOutput(
            loss=masked_im_loss,
            reconstruction=reconstructed_pixel_values,
            hidden_states=outputs.hidden_states,
            reshaped_hidden_states=outputs.reshaped_hidden_states,
        )
@add_start_docstrings(
    """
    FocalNet Model with an image classification head on top (a linear layer on top of the pooled output) e.g. for
    ImageNet.
    """,
    FOCALNET_START_DOCSTRING,
)
class FocalNetForImageClassification(FocalNetPreTrainedModel):
    # Copied from transformers.models.swin.modeling_swin.SwinForImageClassification.__init__ with Swin->FocalNet, swin->focalnet
    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.focalnet = FocalNetModel(config)

        # Classifier head
        self.classifier = (
            nn.Linear(self.focalnet.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(FOCALNET_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=FocalNetImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, FocalNetImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.focalnet(
            pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return FocalNetImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            reshaped_hidden_states=outputs.reshaped_hidden_states,
        )
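

# A minimal usage sketch for FocalNetForImageClassification (illustrative; it uses the
# same checkpoint as _IMAGE_CLASS_CHECKPOINT above, whose ImageNet-1k head produces the
# expected "tabby, tabby cat" label on the COCO cats image):
#
#     from transformers import AutoImageProcessor, FocalNetForImageClassification
#     from PIL import Image
#     import requests
#
#     url = "http://images.cocodataset.org/val2017/000000039769.jpg"
#     image = Image.open(requests.get(url, stream=True).raw)
#     processor = AutoImageProcessor.from_pretrained("microsoft/focalnet-tiny")
#     model = FocalNetForImageClassification.from_pretrained("microsoft/focalnet-tiny")
#     inputs = processor(images=image, return_tensors="pt")
#     logits = model(**inputs).logits
#     predicted_label = logits.argmax(-1).item()
#     print(model.config.id2label[predicted_label])   # e.g. "tabby, tabby cat"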
@add_start_docstrings(
    """
    FocalNet backbone, to be used with frameworks like X-Decoder.
    """,
    FOCALNET_START_DOCSTRING,
)
class FocalNetBackbone(FocalNetPreTrainedModel, BackboneMixin):
    def __init__(self, config: FocalNetConfig):
        super().__init__(config)
        super()._init_backbone(config)

        self.num_features = [config.embed_dim] + config.hidden_sizes
        self.focalnet = FocalNetModel(config)

        # initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(FOCALNET_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> BackboneOutput:
        """
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> processor = AutoImageProcessor.from_pretrained("microsoft/focalnet-tiny-lrf")
        >>> model = AutoBackbone.from_pretrained("microsoft/focalnet-tiny-lrf")

        >>> inputs = processor(image, return_tensors="pt")
        >>> outputs = model(**inputs)
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        outputs = self.focalnet(pixel_values, output_hidden_states=True, return_dict=True)

        hidden_states = outputs.reshaped_hidden_states

        feature_maps = ()
        for idx, stage in enumerate(self.stage_names):
            if stage in self.out_features:
                feature_maps += (hidden_states[idx],)

        if not return_dict:
            output = (feature_maps,)
            if output_hidden_states:
                output += (outputs.hidden_states,)
            return output

        return BackboneOutput(
            feature_maps=feature_maps,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=None,
        )