| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029 |
- # coding=utf-8
- # Copyright 2023 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch FocalNet model."""
- import collections.abc
- import math
- from dataclasses import dataclass
- from typing import Optional, Tuple, Union
- import torch
- import torch.utils.checkpoint
- from torch import nn
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
- from ...activations import ACT2FN
- from ...modeling_outputs import BackboneOutput
- from ...modeling_utils import PreTrainedModel
- from ...utils import (
- ModelOutput,
- add_code_sample_docstrings,
- add_start_docstrings,
- add_start_docstrings_to_model_forward,
- logging,
- replace_return_docstrings,
- )
- from ...utils.backbone_utils import BackboneMixin
- from .configuration_focalnet import FocalNetConfig
# Module-level logger, named after this module.
logger = logging.get_logger(__name__)

# General docstring
# Passed as `config_class` to the docstring decorators applied to the `forward` methods in this file.
_CONFIG_FOR_DOC = "FocalNetConfig"

# Base docstring
# Checkpoint name and expected output shape handed to `add_code_sample_docstrings` on `FocalNetModel.forward`.
_CHECKPOINT_FOR_DOC = "microsoft/focalnet-tiny"
_EXPECTED_OUTPUT_SHAPE = [1, 49, 768]

# Image classification docstring
# NOTE(review): not referenced in the visible part of this file; presumably consumed by the image-classification
# head's code-sample docstring -- confirm before renaming or removing.
_IMAGE_CLASS_CHECKPOINT = "microsoft/focalnet-tiny"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
@dataclass
class FocalNetEncoderOutput(ModelOutput):
    """
    FocalNet encoder's outputs, with potential hidden states.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    """

    # Keep the declaration order: `FocalNetModel.forward` reads this object positionally
    # (`encoder_outputs[0]`, `encoder_outputs[1:]`).
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class FocalNetModelOutput(ModelOutput):
    """
    FocalNet model's outputs that also contains a pooling of the last hidden states.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
            Average pooling of the last layer hidden-state.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    """

    # `pooler_output` stays `None` when the model was built with `add_pooling_layer=False`.
    last_hidden_state: torch.FloatTensor = None
    pooler_output: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class FocalNetMaskedImageModelingOutput(ModelOutput):
    """
    FocalNet masked image model outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
            Masked image modeling (MIM) loss.
        reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Reconstructed pixel values.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    """

    # The reconstructed image is exposed as `reconstruction`; this class declares no `logits` field.
    loss: Optional[torch.FloatTensor] = None
    reconstruction: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class FocalNetImageClassifierOutput(ModelOutput):
    """
    FocalNet outputs for image classification.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, hidden_size, height, width)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
            include the spatial dimensions.
    """

    # All fields default to `None`; `loss` is documented above as present only when `labels` is provided.
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
class FocalNetEmbeddings(nn.Module):
    """
    Stem of the model: patch embeddings followed by a layer norm and dropout, with an optional learnable mask token
    that can be substituted for selected patch embeddings.
    """

    def __init__(self, config, use_mask_token=False):
        super().__init__()

        self.patch_embeddings = FocalNetPatchEmbeddings(
            config=config,
            image_size=config.image_size,
            patch_size=config.patch_size,
            num_channels=config.num_channels,
            embed_dim=config.embed_dim,
            use_conv_embed=config.use_conv_embed,
            is_stem=True,
        )
        # (rows, columns) of the patch grid, as computed by the patch embedding layer
        self.patch_grid = self.patch_embeddings.grid_size

        if use_mask_token:
            self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim))
        else:
            self.mask_token = None

        self.norm = nn.LayerNorm(config.embed_dim, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None
    ) -> Tuple[torch.Tensor]:
        """
        Embed `pixel_values` and return `(embeddings, output_dimensions)`, where `output_dimensions` is whatever the
        patch embedding layer reports. When `bool_masked_pos` is given, positions where it is set are replaced by the
        mask token (this requires the module to have been built with `use_mask_token=True`).
        """
        patch_tokens, output_dimensions = self.patch_embeddings(pixel_values)
        patch_tokens = self.norm(patch_tokens)
        batch_size, seq_len, _ = patch_tokens.size()

        if bool_masked_pos is not None:
            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
            # blend: keep the embedding where the mask is 0, use the mask token where it is 1
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            patch_tokens = patch_tokens * (1.0 - mask) + mask_tokens * mask

        return self.dropout(patch_tokens), output_dimensions
class FocalNetPatchEmbeddings(nn.Module):
    """
    Turns a `(batch_size, num_channels, height, width)` tensor into a sequence of embeddings with a strided
    convolution, optionally followed by a layer norm.

    With `use_conv_embed=True` the convolution uses fixed overlapping kernels (7x7 / stride 4 / padding 2 for the
    stem, 3x3 / stride 2 / padding 1 otherwise); without it, the kernel and stride both equal `patch_size`.
    """

    def __init__(
        self,
        config,
        image_size,
        patch_size,
        num_channels,
        embed_dim,
        add_norm=False,
        use_conv_embed=False,
        is_stem=False,
    ):
        super().__init__()

        # scalars are promoted to (value, value) pairs; iterables are kept as given
        if not isinstance(image_size, collections.abc.Iterable):
            image_size = (image_size, image_size)
        if not isinstance(patch_size, collections.abc.Iterable):
            patch_size = (patch_size, patch_size)

        grid_rows = image_size[0] // patch_size[0]
        grid_columns = image_size[1] // patch_size[1]

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = grid_columns * grid_rows
        self.grid_size = (grid_rows, grid_columns)

        if use_conv_embed:
            # the stem and the later (downsampling) embeddings use different conv geometries
            kernel_size, padding, stride = (7, 2, 4) if is_stem else (3, 1, 2)
            self.projection = nn.Conv2d(
                num_channels, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
            )
        else:
            self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

        self.norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) if add_norm else None

    def maybe_pad(self, pixel_values, height, width):
        """Pad the end of the last two dimensions so that both are multiples of the patch size."""
        width_remainder = width % self.patch_size[1]
        if width_remainder != 0:
            pixel_values = nn.functional.pad(pixel_values, (0, self.patch_size[1] - width_remainder))

        height_remainder = height % self.patch_size[0]
        if height_remainder != 0:
            pixel_values = nn.functional.pad(pixel_values, (0, 0, 0, self.patch_size[0] - height_remainder))

        return pixel_values

    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> Tuple[torch.Tensor, Tuple[int]]:
        """
        Returns the embeddings of shape `(batch_size, out_height * out_width, embed_dim)` together with
        `(out_height, out_width)`, the spatial size of the convolution output.
        """
        _, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )

        # pad the input to be divisible by self.patch_size, if needed
        projected = self.projection(self.maybe_pad(pixel_values, height, width))
        _, _, out_height, out_width = projected.shape

        # b c h w -> b (h w) c
        tokens = projected.flatten(2).transpose(1, 2)
        if self.norm is not None:
            tokens = self.norm(tokens)

        return tokens, (out_height, out_width)
# Adapted from transformers.models.beit.modeling_beit.drop_path
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Stochastic depth: zero out whole samples of `input` with probability `drop_prob` and rescale the survivors by
    `1 / (1 - drop_prob)`. Meant for the main path of residual blocks.

    The input is returned untouched (same object) when `training` is false or `drop_prob` is `0.0`.

    Naming follows Ross Wightman's implementation ("drop path" rather than "DropConnect"), see
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956.
    """
    if not training or drop_prob == 0.0:
        return input

    keep_prob = 1 - drop_prob
    # one random draw per sample, broadcast over every remaining dimension (works for any tensor rank)
    per_sample_shape = (input.shape[0],) + (1,) * (input.ndim - 1)
    keep_mask = torch.rand(per_sample_shape, dtype=input.dtype, device=input.device)
    # uniform[0, 1) + keep_prob, floored -> 1 with probability keep_prob, else 0
    keep_mask = (keep_mask + keep_prob).floor_()
    return input.div(keep_prob) * keep_mask
# Adapted from transformers.models.beit.modeling_beit.BeitDropPath with Beit->FocalNet
class FocalNetDropPath(nn.Module):
    """Module wrapper around `drop_path` (per-sample stochastic depth); active only while the module is training."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return drop_path(hidden_states, drop_prob=self.drop_prob, training=self.training)

    def extra_repr(self) -> str:
        return f"p={self.drop_prob}"
class FocalNetModulation(nn.Module):
    """
    Focal modulation: a linear projection produces a query, a context and per-level gates; the context is passed
    through a stack of depthwise convolutions with growing kernels, the gated results (plus a gated global average)
    are summed into a modulator, and the query is multiplied by that modulator.

    The most recent gates and modulator are kept on the module as `self.gates` and `self.modulator`.
    """

    def __init__(self, config, index, dim, focal_factor=2, bias=True, projection_dropout=0.0):
        super().__init__()

        self.dim = dim
        self.focal_window = config.focal_windows[index]
        self.focal_level = config.focal_levels[index]
        self.focal_factor = focal_factor
        self.use_post_layernorm_in_modulation = config.use_post_layernorm_in_modulation
        self.normalize_modulator = config.normalize_modulator

        # a single projection yields query (dim) + context (dim) + one gate per focal level and one global gate
        self.projection_in = nn.Linear(dim, 2 * dim + (self.focal_level + 1), bias=bias)
        self.projection_context = nn.Conv2d(dim, dim, kernel_size=1, stride=1, bias=bias)

        self.activation = nn.GELU()

        self.projection_out = nn.Linear(dim, dim)
        self.projection_dropout = nn.Dropout(projection_dropout)

        # kernel size grows by `focal_factor` per level, starting from `focal_window`
        self.kernel_sizes = [self.focal_factor * level + self.focal_window for level in range(self.focal_level)]
        self.focal_layers = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(dim, dim, kernel_size=size, stride=1, groups=dim, padding=size // 2, bias=False),
                    nn.GELU(),
                )
                for size in self.kernel_sizes
            ]
        )

        if self.use_post_layernorm_in_modulation:
            self.layernorm = nn.LayerNorm(dim, eps=config.layer_norm_eps)

    def forward(self, hidden_state):
        """
        Args:
            hidden_state:
                Input features with shape of (batch_size, height, width, num_channels)
        """
        num_channels = hidden_state.shape[-1]

        # pre linear projection, then b h w c -> b c h w
        projected = self.projection_in(hidden_state).permute(0, 3, 1, 2).contiguous()
        query, context, self.gates = torch.split(
            projected, (num_channels, num_channels, self.focal_level + 1), 1
        )

        # context aggregation: each level refines the previous level's context
        aggregated = 0
        for level, focal_layer in enumerate(self.focal_layers):
            context = focal_layer(context)
            aggregated = aggregated + context * self.gates[:, level : level + 1]
        global_context = self.activation(context.mean(2, keepdim=True).mean(3, keepdim=True))
        aggregated = aggregated + global_context * self.gates[:, self.focal_level :]

        # normalize context
        if self.normalize_modulator:
            aggregated = aggregated / (self.focal_level + 1)

        # focal modulation
        self.modulator = self.projection_context(aggregated)
        modulated = (query * self.modulator).permute(0, 2, 3, 1).contiguous()
        if self.use_post_layernorm_in_modulation:
            modulated = self.layernorm(modulated)

        # post linear projection
        return self.projection_dropout(self.projection_out(modulated))
class FocalNetMlp(nn.Module):
    """
    Two-layer feed-forward block: linear -> activation (`config.hidden_act`) -> dropout -> linear -> dropout.
    `hidden_features` and `out_features` fall back to `in_features` when not given.
    """

    def __init__(self, config, in_features, hidden_features=None, out_features=None, drop=0.0):
        super().__init__()
        hidden_width = hidden_features or in_features
        output_width = out_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_width)
        self.activation = ACT2FN[config.hidden_act]
        self.fc2 = nn.Linear(hidden_width, output_width)
        # the same dropout module is applied after the activation and after the second linear layer
        self.drop = nn.Dropout(drop)

    def forward(self, hidden_state):
        hidden_state = self.drop(self.activation(self.fc1(hidden_state)))
        return self.drop(self.fc2(hidden_state))
class FocalNetLayer(nn.Module):
    r"""Focal Modulation Network layer (block): a focal modulation sub-block and an MLP sub-block, each wrapped in a
    residual connection with optional layer scale and stochastic depth.

    Args:
        config (`FocalNetConfig`):
            Model config.
        index (`int`):
            Layer index.
        dim (`int`):
            Number of input channels.
        input_resolution (`Tuple[int]`):
            Input resolution.
        drop_path (`float`, *optional*, defaults to 0.0):
            Stochastic depth rate.
    """

    def __init__(self, config, index, dim, input_resolution, drop_path=0.0):
        super().__init__()

        self.config = config

        # layer-specific attributes
        self.dim = dim
        self.input_resolution = input_resolution

        # general attributes
        self.drop = config.hidden_dropout_prob
        self.use_post_layernorm = config.use_post_layernorm

        self.norm1 = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.modulation = FocalNetModulation(config=config, index=index, dim=dim, projection_dropout=self.drop)

        self.drop_path = FocalNetDropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.mlp = FocalNetMlp(
            config=config,
            in_features=dim,
            hidden_features=int(dim * config.mlp_ratio),
            drop=self.drop,
        )

        # layer scale: learnable per-channel factors when enabled, plain 1.0 otherwise
        if config.use_layerscale:
            self.gamma_1 = nn.Parameter(config.layerscale_value * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(config.layerscale_value * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma_1 = 1.0
            self.gamma_2 = 1.0

    def forward(self, hidden_state, input_dimensions):
        """
        Apply the block to `hidden_state` of shape `(batch_size, height * width, num_channels)`, where
        `(height, width) = input_dimensions`. The output has the same shape.
        """
        height, width = input_dimensions
        batch_size, _, num_channels = hidden_state.shape
        residual = hidden_state

        # Focal Modulation (norm1 runs before it in pre-norm mode, after it in post-norm mode)
        if not self.use_post_layernorm:
            hidden_state = self.norm1(hidden_state)
        hidden_state = hidden_state.view(batch_size, height, width, num_channels)
        hidden_state = self.modulation(hidden_state).view(batch_size, height * width, num_channels)
        if self.use_post_layernorm:
            hidden_state = self.norm1(hidden_state)

        hidden_state = residual + self.drop_path(self.gamma_1 * hidden_state)

        # FFN (same pre-/post-norm switch, using norm2)
        if self.use_post_layernorm:
            ffn_output = self.norm2(self.mlp(hidden_state))
        else:
            ffn_output = self.mlp(self.norm2(hidden_state))

        return hidden_state + self.drop_path(self.gamma_2 * ffn_output)
class FocalNetStage(nn.Module):
    """
    One stage of the FocalNet encoder: `config.depths[index]` [`FocalNetLayer`] blocks followed, for every stage but
    the last one, by a [`FocalNetPatchEmbeddings`] layer that downsamples the feature map and changes the channel
    width from `config.embed_dim * 2**index` to `config.embed_dim * 2**(index + 1)`.

    Args:
        config (`FocalNetConfig`):
            Model config.
        index (`int`):
            Index of this stage.
        input_resolution (`Tuple[int]`):
            Resolution handed to the layers and to the downsampling layer as `image_size`.
    """

    def __init__(self, config, index, input_resolution):
        super().__init__()

        self.config = config
        self.num_stages = len(config.depths)
        embed_dim = [config.embed_dim * (2**i) for i in range(self.num_stages)]
        dim = embed_dim[index]
        is_last_stage = index >= self.num_stages - 1

        # stochastic depth decay rule: rates grow linearly over all layers of the model,
        # this stage takes the slice that corresponds to its own layers
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
        drop_path = dpr[sum(config.depths[:index]) : sum(config.depths[: index + 1])]

        self.layers = nn.ModuleList(
            [
                FocalNetLayer(
                    config=config,
                    index=index,
                    dim=dim,
                    input_resolution=input_resolution,
                    drop_path=drop_path[i],
                )
                for i in range(config.depths[index])
            ]
        )

        if is_last_stage:
            self.downsample = None
        else:
            self.downsample = FocalNetPatchEmbeddings(
                config=config,
                image_size=input_resolution,
                patch_size=2,
                num_channels=dim,
                embed_dim=embed_dim[index + 1],
                add_norm=True,
                use_conv_embed=config.use_conv_embed,
                is_stem=False,
            )

        # NOTE(review): not read anywhere in the visible code; kept for compatibility.
        self.pointing = False

    def forward(self, hidden_states: torch.Tensor, input_dimensions: Tuple[int, int]) -> Tuple[torch.Tensor]:
        """
        Run the layers of this stage and, if present, the downsampling layer.

        Returns:
            `(hidden_states, hidden_states_before_downsampling, output_dimensions)` where `output_dimensions` is
            always the 4-tuple `(height, width, output_height, output_width)`: the first two entries describe
            `hidden_states_before_downsampling`, the last two describe the returned `hidden_states`. Without a
            downsampling layer both pairs are identical.
        """
        height, width = input_dimensions
        for layer_module in self.layers:
            hidden_states = layer_module(hidden_states, input_dimensions)

        hidden_states_before_downsampling = hidden_states

        if self.downsample is not None:
            batch_size = hidden_states_before_downsampling.shape[0]
            # b (h w) c -> b c h w, the layout expected by the downsampling convolution
            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, height, width)
            hidden_states, downsampled_dimensions = self.downsample(hidden_states)
            # Report the pre-downsampling size first, like the no-downsample branch does. Previously only the
            # downsampled pair was returned here, so consumers reading `output_dimensions[0:2]` as the original
            # size got the wrong values.
            output_dimensions = (height, width, downsampled_dimensions[0], downsampled_dimensions[1])
        else:
            output_dimensions = (height, width, height, width)

        return (hidden_states, hidden_states_before_downsampling, output_dimensions)
class FocalNetEncoder(nn.Module):
    """
    Stack of [`FocalNetStage`] modules. Stage `i` is built for a resolution of `grid_size // 2**i`.

    Args:
        config (`FocalNetConfig`):
            Model config.
        grid_size (`Tuple[int]`):
            Patch grid `(rows, columns)` of the embedding layer feeding the encoder.
    """

    def __init__(self, config, grid_size):
        super().__init__()
        self.num_stages = len(config.depths)
        self.config = config

        self.stages = nn.ModuleList(
            [
                FocalNetStage(
                    config=config,
                    index=i_layer,
                    input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
                )
                for i_layer in range(self.num_stages)
            ]
        )

        self.gradient_checkpointing = False

    @staticmethod
    def _to_feature_map(hidden_states, dimensions):
        """Rearrange `b (h w) c -> b c h w` using `dimensions = (h, w)`."""
        batch_size, _, hidden_size = hidden_states.shape
        return hidden_states.view(batch_size, *dimensions, hidden_size).permute(0, 3, 1, 2)

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: Tuple[int, int],
        output_hidden_states: Optional[bool] = False,
        output_hidden_states_before_downsampling: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, "FocalNetEncoderOutput"]:
        """
        Run all stages in order.

        Args:
            hidden_states (`torch.Tensor`):
                Embeddings of shape `(batch_size, height * width, hidden_size)`.
            input_dimensions (`Tuple[int, int]`):
                `(height, width)` of the token grid of `hidden_states`.
            output_hidden_states (`bool`, *optional*, defaults to `False`):
                Whether to collect the input embeddings and one hidden state per stage.
            output_hidden_states_before_downsampling (`bool`, *optional*, defaults to `False`):
                When collecting hidden states, record each stage's output before its downsampling layer instead of
                after it.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`FocalNetEncoderOutput`] instead of a plain tuple.
        """
        all_hidden_states = () if output_hidden_states else None
        all_reshaped_hidden_states = () if output_hidden_states else None

        if output_hidden_states:
            all_hidden_states += (hidden_states,)
            all_reshaped_hidden_states += (self._to_feature_map(hidden_states, input_dimensions),)

        for stage_module in self.stages:
            # Token grid of this stage's input. The stage's layers keep the grid unchanged, so this is also the grid
            # of `hidden_states_before_downsampling`.
            stage_input_dimensions = input_dimensions

            if self.gradient_checkpointing and self.training:
                stage_outputs = self._gradient_checkpointing_func(
                    stage_module.__call__,
                    hidden_states,
                    input_dimensions,
                )
            else:
                stage_outputs = stage_module(hidden_states, input_dimensions)

            hidden_states = stage_outputs[0]
            hidden_states_before_downsampling = stage_outputs[1]
            output_dimensions = stage_outputs[2]

            # the last two entries are always the grid of the (possibly downsampled) stage output
            input_dimensions = (output_dimensions[-2], output_dimensions[-1])

            if output_hidden_states:
                if output_hidden_states_before_downsampling:
                    # Use the original (not downsampled) height and width. Reading them from
                    # `output_dimensions[0:2]` is wrong whenever the stage only reports the downsampled pair, and
                    # made the reshape fail for every stage that downsamples.
                    recorded_state = hidden_states_before_downsampling
                    recorded_dimensions = stage_input_dimensions
                else:
                    recorded_state = hidden_states
                    recorded_dimensions = input_dimensions

                all_hidden_states += (recorded_state,)
                all_reshaped_hidden_states += (self._to_feature_map(recorded_state, recorded_dimensions),)

        if not return_dict:
            # NOTE(review): the tuple form does not include `all_reshaped_hidden_states`; left as is because
            # positional consumers depend on the current layout.
            return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)

        return FocalNetEncoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            reshaped_hidden_states=all_reshaped_hidden_states,
        )
# Adapted from transformers.models.swin.modeling_swin.SwinPreTrainedModel with Swin->FocalNet,swin->focalnet
class FocalNetPreTrainedModel(PreTrainedModel):
    """
    Common base of the FocalNet models: declares the configuration class and loading metadata, and provides the
    weight initialization scheme.
    """

    config_class = FocalNetConfig
    base_model_prefix = "focalnet"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _no_split_modules = ["FocalNetStage"]

    def _init_weights(self, module):
        """Initialize the weights of a single submodule; module types not listed here are left untouched."""
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return

        if not isinstance(module, (nn.Linear, nn.Conv2d)):
            return

        # Slightly different from the TF version which uses truncated_normal for initialization
        # cf https://github.com/pytorch/pytorch/pull/5617
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if module.bias is not None:
            module.bias.data.zero_()
# Shared class-level documentation; passed to `add_start_docstrings` on the model classes below.
FOCALNET_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`FocalNetConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Shared argument documentation; passed to `add_start_docstrings_to_model_forward` on the `forward` methods below.
FOCALNET_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`AutoImageProcessor.__call__`] for details.

        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
    "The bare FocalNet Model outputting raw hidden-states without any specific head on top.",
    FOCALNET_START_DOCSTRING,
)
class FocalNetModel(FocalNetPreTrainedModel):
    # Pipeline: embeddings -> encoder -> final layer norm -> optional average pooling over the sequence.
    def __init__(self, config, add_pooling_layer=True, use_mask_token=False):
        super().__init__(config)
        self.config = config
        self.num_stages = len(config.depths)
        # channel width doubles at every stage after the first
        self.num_features = int(config.embed_dim * 2 ** (self.num_stages - 1))

        self.embeddings = FocalNetEmbeddings(config, use_mask_token=use_mask_token)
        self.encoder = FocalNetEncoder(config, self.embeddings.patch_grid)

        self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
        self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    @add_start_docstrings_to_model_forward(FOCALNET_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=FocalNetModelOutput,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, FocalNetModelOutput]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        """
        # unset flags fall back to the configuration
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)

        encoder_outputs = self.encoder(
            embedding_output,
            input_dimensions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = self.layernorm(encoder_outputs[0])

        if self.pooler is None:
            pooled_output = None
        else:
            # pool over the sequence dimension: (b, seq, c) -> (b, c, seq) -> (b, c, 1) -> (b, c)
            pooled_output = torch.flatten(self.pooler(sequence_output.transpose(1, 2)), 1)

        if return_dict:
            return FocalNetModelOutput(
                last_hidden_state=sequence_output,
                pooler_output=pooled_output,
                hidden_states=encoder_outputs.hidden_states,
                reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
            )

        return (sequence_output, pooled_output) + encoder_outputs[1:]
@add_start_docstrings(
    """FocalNet Model with a decoder on top for masked image modeling.
    This follows the same implementation as in [SimMIM](https://arxiv.org/abs/2111.09886).
    <Tip>
    Note that we provide a script to pre-train this model on custom data in our [examples
    directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).
    </Tip>
    """,
    FOCALNET_START_DOCSTRING,
)
class FocalNetForMaskedImageModeling(FocalNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # Encoder is built without a pooling layer and with the mask token enabled.
        self.focalnet = FocalNetModel(config, add_pooling_layer=False, use_mask_token=True)

        # Decoder input width: embed_dim doubled once per stage after the first.
        self.num_stages = len(config.depths)
        num_features = int(config.embed_dim * 2 ** (self.num_stages - 1))

        # 1x1 convolution producing encoder_stride**2 * num_channels channels, which
        # PixelShuffle then rearranges by a factor of encoder_stride.
        projection = nn.Conv2d(
            in_channels=num_features,
            out_channels=config.encoder_stride**2 * config.num_channels,
            kernel_size=1,
        )
        upsampler = nn.PixelShuffle(config.encoder_stride)
        self.decoder = nn.Sequential(projection, upsampler)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(FOCALNET_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FocalNetMaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, FocalNetMaskedImageModelingOutput]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        Returns:
        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, FocalNetConfig, FocalNetForMaskedImageModeling
        >>> import torch
        >>> from PIL import Image
        >>> import requests
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/focalnet-base-simmim-window6-192")
        >>> config = FocalNetConfig()
        >>> model = FocalNetForMaskedImageModeling(config)
        >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
        >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
        >>> # create random boolean mask of shape (batch_size, num_patches)
        >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
        >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
        >>> loss, reconstructed_pixel_values = outputs.loss, outputs.logits
        >>> list(reconstructed_pixel_values.shape)
        [1, 3, 192, 192]
        ```"""
        # Fall back to the config default only when the caller gave no preference.
        if return_dict is None:
            return_dict = self.config.use_return_dict

        backbone_outputs = self.focalnet(
            pixel_values,
            bool_masked_pos=bool_masked_pos,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Swap axes 1 and 2, then fold the last axis into a square grid:
        # result is (batch_size, num_channels, height, width).
        tokens = backbone_outputs[0].transpose(1, 2)
        batch_size, num_channels, sequence_length = tokens.shape
        side = math.floor(sequence_length**0.5)
        feature_map = tokens.reshape(batch_size, num_channels, side, side)

        # Reconstruct pixel values
        reconstructed_pixel_values = self.decoder(feature_map)

        masked_im_loss = None
        if bool_masked_pos is not None:
            grid = self.config.image_size // self.config.patch_size
            patch_size = self.config.patch_size

            # Expand the per-patch mask by patch_size along both grid axes and
            # insert a singleton axis at position 1.
            patch_mask = bool_masked_pos.reshape(-1, grid, grid)
            pixel_mask = patch_mask.repeat_interleave(patch_size, 1)
            pixel_mask = pixel_mask.repeat_interleave(patch_size, 2)
            pixel_mask = pixel_mask.unsqueeze(1).contiguous()

            # Element-wise L1, summed over masked positions only; the 1e-5 keeps
            # the division defined when the mask sum is zero.
            elementwise_l1 = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
            masked_total = (elementwise_l1 * pixel_mask).sum()
            masked_im_loss = masked_total / (pixel_mask.sum() + 1e-5) / self.config.num_channels

        if return_dict:
            return FocalNetMaskedImageModelingOutput(
                loss=masked_im_loss,
                reconstruction=reconstructed_pixel_values,
                hidden_states=backbone_outputs.hidden_states,
                reshaped_hidden_states=backbone_outputs.reshaped_hidden_states,
            )

        # Tuple form: reconstruction followed by everything after the first two
        # backbone entries, with the loss prepended only when it was computed.
        output = (reconstructed_pixel_values,) + backbone_outputs[2:]
        if masked_im_loss is None:
            return output
        return (masked_im_loss,) + output
@add_start_docstrings(
    """
    FocalNet Model with an image classification head on top (a linear layer on top of the pooled output) e.g. for
    ImageNet.
    """,
    FOCALNET_START_DOCSTRING,
)
class FocalNetForImageClassification(FocalNetPreTrainedModel):
    # Copied from transformers.models.swin.modeling_swin.SwinForImageClassification.__init__ with Swin->FocalNet, swin->focalnet
    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.focalnet = FocalNetModel(config)

        # Classifier head
        self.classifier = (
            nn.Linear(self.focalnet.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(FOCALNET_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=FocalNetImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, FocalNetImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        # Fall back to the config default only when the caller gave no preference.
        if return_dict is None:
            return_dict = self.config.use_return_dict

        backbone_outputs = self.focalnet(
            pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Entry 1 of the backbone output is the pooled representation fed to the head.
        logits = self.classifier(backbone_outputs[1])

        loss = None
        if labels is not None:
            # First labelled call with an unset problem type: derive it from the
            # label count and label dtype, and store it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            problem_type = self.config.problem_type
            if problem_type == "regression":
                criterion = MSELoss()
                if self.num_labels == 1:
                    # Single target: squeeze both sides before comparing.
                    loss = criterion(logits.squeeze(), labels.squeeze())
                else:
                    loss = criterion(logits, labels)
            elif problem_type == "single_label_classification":
                criterion = CrossEntropyLoss()
                loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                criterion = BCEWithLogitsLoss()
                loss = criterion(logits, labels)

        if return_dict:
            return FocalNetImageClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=backbone_outputs.hidden_states,
                reshaped_hidden_states=backbone_outputs.reshaped_hidden_states,
            )

        # Tuple form: logits followed by everything after the first two backbone
        # entries, with the loss prepended only when it was computed.
        output = (logits,) + backbone_outputs[2:]
        if loss is None:
            return output
        return (loss,) + output
@add_start_docstrings(
    """
    FocalNet backbone, to be used with frameworks like X-Decoder.
    """,
    FOCALNET_START_DOCSTRING,
)
class FocalNetBackbone(FocalNetPreTrainedModel, BackboneMixin):
    def __init__(self, config: FocalNetConfig):
        super().__init__(config)
        super()._init_backbone(config)

        # One width for the embedding output followed by the configured hidden sizes.
        self.num_features = [config.embed_dim] + config.hidden_sizes
        self.focalnet = FocalNetModel(config)

        # initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(FOCALNET_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> BackboneOutput:
        """
        Returns:
        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, AutoBackbone
        >>> import torch
        >>> from PIL import Image
        >>> import requests
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> processor = AutoImageProcessor.from_pretrained("microsoft/focalnet-tiny-lrf")
        >>> model = AutoBackbone.from_pretrained("microsoft/focalnet-tiny-lrf")
        >>> inputs = processor(image, return_tensors="pt")
        >>> outputs = model(**inputs)
        ```"""
        # Fall back to config defaults only for options the caller left unset.
        if return_dict is None:
            return_dict = self.config.use_return_dict
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        # The inner model is always asked for hidden states in dict form, because
        # the feature maps are taken from its reshaped hidden states.
        outputs = self.focalnet(pixel_values, output_hidden_states=True, return_dict=True)
        stage_maps = outputs.reshaped_hidden_states

        # Keep only the stages listed in out_features, in stage order.
        feature_maps = tuple(
            stage_maps[position]
            for position, stage_name in enumerate(self.stage_names)
            if stage_name in self.out_features
        )

        if return_dict:
            return BackboneOutput(
                feature_maps=feature_maps,
                hidden_states=outputs.hidden_states if output_hidden_states else None,
                attentions=None,
            )

        # Tuple form: hidden states are appended only when the caller asked for them.
        if output_hidden_states:
            return (feature_maps, outputs.hidden_states)
        return (feature_maps,)
|