modeling_sam.py 63 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412
  1. # coding=utf-8
  2. # Copyright 2023 The Meta AI Authors and The HuggingFace Team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch SAM model."""
  16. import collections
  17. from dataclasses import dataclass
  18. from typing import Dict, List, Optional, Tuple, Union
  19. import numpy as np
  20. import torch
  21. import torch.nn.functional as F
  22. import torch.utils.checkpoint
  23. from torch import Tensor, nn
  24. from ...activations import ACT2FN
  25. from ...modeling_outputs import BaseModelOutput
  26. from ...modeling_utils import PreTrainedModel
  27. from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
  28. from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
# Module-level logger obtained from the library's logging utilities.
logger = logging.get_logger(__name__)

# Names used for generated documentation; presumably consumed by the
# `add_start_docstrings*` helpers later in the file (not visible in this chunk).
_CONFIG_FOR_DOC = "SamConfig"
_CHECKPOINT_FOR_DOC = "facebook/sam-vit-huge"
@dataclass
class SamVisionEncoderOutput(ModelOutput):
    """
    Base class for the SAM vision model's outputs that also contains image embeddings obtained by applying the
    projection layer to the pooler_output.

    Args:
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
            The image embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # NOTE: field order matters — ModelOutput exposes the fields positionally in this order.
    image_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class SamImageSegmentationOutput(ModelOutput):
    """
    Base class for the Segment-Anything model's output.

    Args:
        iou_scores (`torch.FloatTensor` of shape `(batch_size, num_masks)`):
            The IoU scores of the predicted masks.
        pred_masks (`torch.FloatTensor` of shape `(batch_size, num_masks, height, width)`):
            The predicted low-resolution masks. Needs to be post-processed by the processor.
        vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs.
        vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    NOTE(review): `SamMaskDecoder` in this file returns masks shaped
    `(batch_size, point_batch_size, num_masks, height, width)` and IoU predictions shaped
    `(batch_size, point_batch_size, num_masks)`, and the per-layer tensors it collects as "attentions" are attention
    *outputs* rather than softmax weights. Confirm whether the model forwards those unchanged into this class; if so,
    the shapes documented above are missing the point-batch dimension.
    """

    # NOTE: field order matters — ModelOutput exposes the fields positionally in this order.
    iou_scores: torch.FloatTensor = None
    pred_masks: torch.FloatTensor = None
    vision_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    vision_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    mask_decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
  85. class SamPatchEmbeddings(nn.Module):
  86. """
  87. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  88. `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
  89. Transformer.
  90. """
  91. def __init__(self, config):
  92. super().__init__()
  93. image_size, patch_size = config.image_size, config.patch_size
  94. num_channels, hidden_size = config.num_channels, config.hidden_size
  95. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  96. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  97. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  98. self.image_size = image_size
  99. self.patch_size = patch_size
  100. self.num_channels = num_channels
  101. self.num_patches = num_patches
  102. self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
  103. def forward(self, pixel_values):
  104. batch_size, num_channels, height, width = pixel_values.shape
  105. if num_channels != self.num_channels:
  106. raise ValueError(
  107. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  108. )
  109. if height != self.image_size[0] or width != self.image_size[1]:
  110. raise ValueError(
  111. f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
  112. )
  113. embeddings = self.projection(pixel_values).permute(0, 2, 3, 1)
  114. return embeddings
  115. class SamMLPBlock(nn.Module):
  116. def __init__(self, config):
  117. super().__init__()
  118. self.lin1 = nn.Linear(config.hidden_size, config.mlp_dim)
  119. self.lin2 = nn.Linear(config.mlp_dim, config.hidden_size)
  120. self.act = ACT2FN[config.hidden_act]
  121. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  122. hidden_states = self.lin1(hidden_states)
  123. hidden_states = self.act(hidden_states)
  124. hidden_states = self.lin2(hidden_states)
  125. return hidden_states
# Copied from transformers.models.convnext.modeling_convnext.ConvNextLayerNorm with ConvNext->Sam
class SamLayerNorm(nn.Module):
    r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
    `data_format` gives the ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape
    (batch_size, height, width, channels) while channels_first corresponds to inputs with shape (batch_size, channels,
    height, width).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        # Learnable per-channel affine parameters (scale initialized to 1, shift to 0).
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError(f"Unsupported data format: {self.data_format}")
        # Wrapped in a tuple for F.layer_norm; assumes normalized_shape is a single int (channel count).
        self.normalized_shape = (normalized_shape,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize `x` over its channel dimension (last for channels_last, dim 1 for channels_first)."""
        if self.data_format == "channels_last":
            x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            # Manual layer norm over dim 1; statistics are computed in float32 and the result is cast back to the
            # input dtype before the affine transform is applied.
            input_dtype = x.dtype
            x = x.float()
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)  # biased variance
            x = (x - u) / torch.sqrt(s + self.eps)
            x = x.to(dtype=input_dtype)
            # Broadcast (C,) parameters over (batch, C, H, W).
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x
  153. class SamAttention(nn.Module):
  154. """
  155. SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
  156. values.
  157. """
  158. def __init__(self, config, downsample_rate=None):
  159. super().__init__()
  160. self.hidden_size = config.hidden_size
  161. downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
  162. self.internal_dim = config.hidden_size // downsample_rate
  163. self.num_attention_heads = config.num_attention_heads
  164. if self.internal_dim % config.num_attention_heads != 0:
  165. raise ValueError("num_attention_heads must divide hidden_size.")
  166. self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
  167. self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
  168. self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
  169. self.out_proj = nn.Linear(self.internal_dim, self.hidden_size)
  170. def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor:
  171. batch, point_batch_size, n_tokens, channel = hidden_states.shape
  172. c_per_head = channel // num_attention_heads
  173. hidden_states = hidden_states.reshape(batch * point_batch_size, n_tokens, num_attention_heads, c_per_head)
  174. return hidden_states.transpose(1, 2)
  175. def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor:
  176. batch, n_heads, n_tokens, c_per_head = hidden_states.shape
  177. hidden_states = hidden_states.transpose(1, 2)
  178. return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
  179. def forward(self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Tensor = None) -> Tensor:
  180. # Input projections
  181. query = self.q_proj(query)
  182. key = self.k_proj(key)
  183. value = self.v_proj(value)
  184. point_batch_size = query.shape[1]
  185. # Separate into heads
  186. query = self._separate_heads(query, self.num_attention_heads)
  187. key = self._separate_heads(key, self.num_attention_heads)
  188. value = self._separate_heads(value, self.num_attention_heads)
  189. # SamAttention
  190. _, _, _, c_per_head = query.shape
  191. attn = query @ key.permute(0, 1, 3, 2) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens
  192. attn = attn / (c_per_head**0.5)
  193. attn = torch.softmax(attn, dim=-1)
  194. if attention_similarity is not None:
  195. attn = attn + attention_similarity
  196. attn = torch.softmax(attn, dim=-1)
  197. # Get output
  198. out = attn @ value
  199. out = self._recombine_heads(out, point_batch_size)
  200. out = self.out_proj(out)
  201. return out
  202. class SamTwoWayAttentionBlock(nn.Module):
  203. def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False):
  204. """
  205. A transformer block with four layers:
  206. (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on
  207. sparse inputs (4) cross attention of dense inputs -> sparse inputs
  208. Arguments:
  209. config (`SamMaskDecoderConfig`):
  210. The configuration file used to instantiate the block
  211. attention_downsample_rate (*optionalk*, int, defaults to 2):
  212. The downsample ratio of the block used to reduce the inner dim of the attention.
  213. skip_first_layer_pe (*optional*, bool, defaults to `False`):
  214. Whether or not to skip the addition of the query_point_embedding on the first layer.
  215. """
  216. super().__init__()
  217. self.hidden_size = config.hidden_size
  218. self.layer_norm_eps = config.layer_norm_eps
  219. self.self_attn = SamAttention(config, downsample_rate=1)
  220. self.layer_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
  221. self.cross_attn_token_to_image = SamAttention(config, downsample_rate=attention_downsample_rate)
  222. self.layer_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
  223. self.mlp = SamMLPBlock(config)
  224. self.layer_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
  225. self.layer_norm4 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
  226. self.cross_attn_image_to_token = SamAttention(config, downsample_rate=attention_downsample_rate)
  227. self.skip_first_layer_pe = skip_first_layer_pe
  228. def forward(
  229. self,
  230. queries: Tensor,
  231. keys: Tensor,
  232. query_point_embedding: Tensor,
  233. key_point_embedding: Tensor,
  234. attention_similarity: Tensor,
  235. output_attentions: bool = False,
  236. ):
  237. # Self attention block
  238. if self.skip_first_layer_pe:
  239. queries = self.self_attn(query=queries, key=queries, value=queries)
  240. else:
  241. query = queries + query_point_embedding
  242. attn_out = self.self_attn(query=query, key=query, value=queries)
  243. queries = queries + attn_out
  244. queries = self.layer_norm1(queries)
  245. # Cross attention block, tokens attending to image embedding
  246. query = queries + query_point_embedding
  247. key = keys + key_point_embedding
  248. attn_out = self.cross_attn_token_to_image(
  249. query=query, key=key, value=keys, attention_similarity=attention_similarity
  250. )
  251. queries = queries + attn_out
  252. queries = self.layer_norm2(queries)
  253. # MLP block
  254. mlp_out = self.mlp(queries)
  255. queries = queries + mlp_out
  256. queries = self.layer_norm3(queries)
  257. # Cross attention block, image embedding attending to tokens
  258. query = queries + query_point_embedding
  259. key = keys + key_point_embedding
  260. attn_out = self.cross_attn_image_to_token(query=key, key=query, value=queries)
  261. keys = keys + attn_out
  262. keys = self.layer_norm4(keys)
  263. outputs = (queries, keys)
  264. if output_attentions:
  265. outputs = outputs + (attn_out,)
  266. else:
  267. outputs = outputs + (None,)
  268. return outputs
  269. class SamTwoWayTransformer(nn.Module):
  270. def __init__(self, config: SamMaskDecoderConfig):
  271. super().__init__()
  272. self.config = config
  273. self.num_hidden_layers = config.num_hidden_layers
  274. self.layers = nn.ModuleList()
  275. for i in range(self.num_hidden_layers):
  276. self.layers.append(SamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0)))
  277. self.final_attn_token_to_image = SamAttention(config)
  278. self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)
  279. def forward(
  280. self,
  281. point_embeddings: Tensor,
  282. image_embeddings: Tensor,
  283. image_positional_embeddings: Tensor,
  284. attention_similarity: Tensor,
  285. target_embedding=None,
  286. output_attentions: Optional[bool] = None,
  287. output_hidden_states: Optional[bool] = None,
  288. return_dict: Optional[bool] = None,
  289. ) -> Union[Tuple, BaseModelOutput]:
  290. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  291. output_hidden_states = (
  292. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  293. )
  294. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  295. all_attentions = ()
  296. if image_embeddings is None:
  297. raise ValueError("You have to specify an image_embedding")
  298. image_embeddings = image_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)
  299. image_positional_embeddings = image_positional_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)
  300. # Prepare queries
  301. queries = point_embeddings
  302. keys = image_embeddings
  303. # Apply transformer blocks and final layernorm
  304. for layer in self.layers:
  305. if target_embedding is not None:
  306. queries += target_embedding
  307. queries, keys, attention_outputs = layer(
  308. queries=queries,
  309. keys=keys,
  310. query_point_embedding=point_embeddings,
  311. key_point_embedding=image_positional_embeddings,
  312. attention_similarity=attention_similarity,
  313. output_attentions=output_attentions,
  314. )
  315. if output_attentions:
  316. all_attentions = all_attentions + (attention_outputs,)
  317. # Apply the final attenion layer from the points to the image
  318. query = queries + point_embeddings
  319. key = keys + image_positional_embeddings
  320. attn_out = self.final_attn_token_to_image(query=query, key=key, value=keys)
  321. queries = queries + attn_out
  322. queries = self.layer_norm_final_attn(queries)
  323. return queries, keys, all_attentions
  324. class SamFeedForward(nn.Module):
  325. def __init__(
  326. self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False
  327. ):
  328. super().__init__()
  329. self.num_layers = num_layers
  330. self.activation = nn.ReLU()
  331. self.proj_in = nn.Linear(input_dim, hidden_dim)
  332. self.proj_out = nn.Linear(hidden_dim, output_dim)
  333. self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)])
  334. self.sigmoid_output = sigmoid_output
  335. def forward(self, hidden_states):
  336. hidden_states = self.proj_in(hidden_states)
  337. hidden_states = self.activation(hidden_states)
  338. for layer in self.layers:
  339. hidden_states = self.activation(layer(hidden_states))
  340. hidden_states = self.proj_out(hidden_states)
  341. if self.sigmoid_output:
  342. hidden_states = F.sigmoid(hidden_states)
  343. return hidden_states
  344. class SamMaskDecoder(nn.Module):
  345. def __init__(self, config: SamMaskDecoderConfig):
  346. super().__init__()
  347. self.hidden_size = config.hidden_size
  348. self.num_multimask_outputs = config.num_multimask_outputs
  349. self.num_mask_tokens = config.num_multimask_outputs + 1
  350. self.iou_token = nn.Embedding(1, self.hidden_size)
  351. self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size)
  352. self.transformer = SamTwoWayTransformer(config)
  353. # should we create a new class for this?
  354. self.upscale_conv1 = nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2)
  355. self.upscale_conv2 = nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2)
  356. self.upscale_layer_norm = SamLayerNorm(self.hidden_size // 4, data_format="channels_first")
  357. self.activation = nn.GELU()
  358. mlps_list = []
  359. for _ in range(self.num_mask_tokens):
  360. mlps_list += [SamFeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)]
  361. self.output_hypernetworks_mlps = nn.ModuleList(mlps_list)
  362. self.iou_prediction_head = SamFeedForward(
  363. self.hidden_size, config.iou_head_hidden_dim, self.num_mask_tokens, config.iou_head_depth
  364. )
  365. def forward(
  366. self,
  367. image_embeddings: torch.Tensor,
  368. image_positional_embeddings: torch.Tensor,
  369. sparse_prompt_embeddings: torch.Tensor,
  370. dense_prompt_embeddings: torch.Tensor,
  371. multimask_output: bool,
  372. output_attentions: Optional[bool] = None,
  373. attention_similarity: torch.Tensor = None,
  374. target_embedding: torch.Tensor = None,
  375. ) -> Tuple[torch.Tensor, torch.Tensor]:
  376. """
  377. Predict masks given image and prompt embeddings.
  378. Args:
  379. image_embeddings (`torch.Tensor`):
  380. the embeddings from the image encoder
  381. image_positional_embedding (`torch.Tensor`):
  382. positional encoding with the shape of image_embeddings
  383. sparse_prompt_embeddings (`torch.Tensor`):
  384. The embeddings of the points and boxes
  385. dense_prompt_embeddings (`torch.Tensor`):
  386. the embeddings of the mask inputs
  387. multimask_output (bool):
  388. Whether to return multiple masks or a single mask.
  389. output_attentions (bool, *optional*):
  390. Whether or not to return the attentions tensors of all attention layers.
  391. """
  392. batch_size, num_channels, height, width = image_embeddings.shape
  393. point_batch_size = sparse_prompt_embeddings.shape[1]
  394. # Concatenate output tokens
  395. output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
  396. output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
  397. if sparse_prompt_embeddings.sum().item() != 0:
  398. tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
  399. else:
  400. tokens = output_tokens
  401. point_embeddings = tokens.to(self.iou_token.weight.dtype)
  402. # Expand per-image data in batch direction to be per-point
  403. image_embeddings = image_embeddings + dense_prompt_embeddings
  404. image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0)
  405. image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)
  406. # Run the transformer, image_positional_embedding are consumed
  407. point_embedding, image_embeddings, attentions = self.transformer(
  408. point_embeddings=point_embeddings,
  409. image_embeddings=image_embeddings,
  410. image_positional_embeddings=image_positional_embeddings,
  411. attention_similarity=attention_similarity,
  412. target_embedding=target_embedding,
  413. output_attentions=output_attentions,
  414. )
  415. iou_token_out = point_embedding[:, :, 0, :]
  416. mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :]
  417. # Upscale mask embeddings and predict masks using the mask tokens
  418. image_embeddings = image_embeddings.transpose(2, 3).reshape(
  419. batch_size * point_batch_size, num_channels, height, width
  420. )
  421. upscaled_embedding = self.upscale_conv1(image_embeddings)
  422. upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
  423. upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding))
  424. hyper_in_list = []
  425. for i in range(self.num_mask_tokens):
  426. current_mlp = self.output_hypernetworks_mlps[i]
  427. hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
  428. hyper_in = torch.stack(hyper_in_list, dim=2)
  429. _, num_channels, height, width = upscaled_embedding.shape
  430. upscaled_embedding = upscaled_embedding.reshape(batch_size, point_batch_size, num_channels, height * width)
  431. masks = (hyper_in @ upscaled_embedding).reshape(batch_size, point_batch_size, -1, height, width)
  432. # Generate mask quality predictions
  433. iou_pred = self.iou_prediction_head(iou_token_out)
  434. # Select the correct mask or masks for output
  435. if multimask_output:
  436. mask_slice = slice(1, None)
  437. else:
  438. mask_slice = slice(0, 1)
  439. masks = masks[:, :, mask_slice, :, :]
  440. iou_pred = iou_pred[:, :, mask_slice]
  441. outputs = (masks, iou_pred)
  442. if output_attentions:
  443. outputs = outputs + (attentions,)
  444. else:
  445. outputs = outputs + (None,)
  446. return outputs
  447. class SamPositionalEmbedding(nn.Module):
  448. def __init__(self, config):
  449. super().__init__()
  450. self.scale = config.hidden_size // 2
  451. self.register_buffer("positional_embedding", self.scale * torch.randn((2, config.num_pos_feats)))
  452. def forward(self, input_coords, input_shape=None):
  453. """Positionally encode points that are normalized to [0,1]."""
  454. coordinates = input_coords.clone()
  455. if input_shape is not None:
  456. coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]
  457. coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]
  458. # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
  459. coordinates = 2 * coordinates - 1
  460. coordinates = coordinates.to(self.positional_embedding.dtype)
  461. coordinates = coordinates @ self.positional_embedding
  462. coordinates = 2 * np.pi * coordinates
  463. # outputs d_1 x ... x d_n x channel shape
  464. return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)
  465. class SamMaskEmbedding(nn.Module):
  466. def __init__(self, config: SamPromptEncoderConfig):
  467. super().__init__()
  468. self.mask_input_channels = config.mask_input_channels // 4
  469. self.activation = ACT2FN[config.hidden_act]
  470. self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2)
  471. self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2)
  472. self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1)
  473. self.layer_norm1 = SamLayerNorm(
  474. self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first"
  475. )
  476. self.layer_norm2 = SamLayerNorm(
  477. self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first"
  478. )
  479. def forward(self, masks):
  480. hidden_states = self.conv1(masks)
  481. hidden_states = self.layer_norm1(hidden_states)
  482. hidden_states = self.activation(hidden_states)
  483. hidden_states = self.conv2(hidden_states)
  484. hidden_states = self.layer_norm2(hidden_states)
  485. hidden_states = self.activation(hidden_states)
  486. dense_embeddings = self.conv3(hidden_states)
  487. return dense_embeddings
  488. class SamPromptEncoder(nn.Module):
  489. def __init__(self, config: SamPromptEncoderConfig, shared_patch_embedding):
  490. super().__init__()
  491. self.shared_embedding = shared_patch_embedding
  492. self.mask_embed = SamMaskEmbedding(config)
  493. self.no_mask_embed = nn.Embedding(1, config.hidden_size)
  494. self.image_embedding_size = (config.image_embedding_size, config.image_embedding_size)
  495. self.input_image_size = config.image_size
  496. self.point_embed = nn.ModuleList(
  497. [nn.Embedding(1, config.hidden_size) for i in range(config.num_point_embeddings)]
  498. )
  499. self.hidden_size = config.hidden_size
  500. self.not_a_point_embed = nn.Embedding(1, config.hidden_size)
  501. def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
  502. """Embeds point prompts."""
  503. points = points + 0.5 # Shift to center of pixel
  504. if pad:
  505. target_point_shape = (points.shape[0], points.shape[1], 1, points.shape[-1])
  506. target_labels_shape = (points.shape[0], points.shape[1], 1)
  507. padding_point = torch.zeros(target_point_shape, device=points.device)
  508. padding_label = -torch.ones(target_labels_shape, device=labels.device)
  509. points = torch.cat([points, padding_point], dim=2)
  510. labels = torch.cat([labels, padding_label], dim=2)
  511. input_shape = (self.input_image_size, self.input_image_size)
  512. point_embedding = self.shared_embedding(points, input_shape)
  513. # torch.where and expanding the labels tensor is required by the ONNX export
  514. point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding)
  515. # This is required for the ONNX export. The dtype, device need to be explicitely
  516. # specificed as otherwise torch.onnx.export interprets as double
  517. point_embedding = torch.where(
  518. labels[..., None] != -10,
  519. point_embedding,
  520. torch.tensor(0.0, dtype=point_embedding.dtype, device=point_embedding.device),
  521. )
  522. point_embedding = torch.where(
  523. (labels == 0)[:, :, :, None],
  524. point_embedding + self.point_embed[0].weight[None, None, :, :],
  525. point_embedding,
  526. )
  527. point_embedding = torch.where(
  528. (labels == 1)[:, :, :, None],
  529. point_embedding + self.point_embed[1].weight[None, None, :, :],
  530. point_embedding,
  531. )
  532. return point_embedding
  533. def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
  534. """Embeds box prompts."""
  535. boxes = boxes + 0.5 # Shift to center of pixel
  536. batch_size, nb_boxes = boxes.shape[:2]
  537. coords = boxes.reshape(batch_size, nb_boxes, 2, 2)
  538. input_shape = (self.input_image_size, self.input_image_size)
  539. corner_embedding = self.shared_embedding(coords, input_shape)
  540. corner_embedding[:, :, 0, :] += self.point_embed[2].weight
  541. corner_embedding[:, :, 1, :] += self.point_embed[3].weight
  542. return corner_embedding
  543. def forward(
  544. self,
  545. input_points: Optional[Tuple[torch.Tensor, torch.Tensor]],
  546. input_labels: Optional[torch.Tensor],
  547. input_boxes: Optional[torch.Tensor],
  548. input_masks: Optional[torch.Tensor],
  549. ) -> Tuple[torch.Tensor, torch.Tensor]:
  550. """
  551. Embeds different types of prompts, returning both sparse and dense embeddings.
  552. Args:
  553. points (`torch.Tensor`, *optional*):
  554. point coordinates and labels to embed.
  555. boxes (`torch.Tensor`, *optional*):
  556. boxes to embed
  557. masks (`torch.Tensor`, *optional*):
  558. masks to embed
  559. """
  560. sparse_embeddings = None
  561. batch_size = 1
  562. target_device = self.shared_embedding.positional_embedding.device
  563. if input_points is not None:
  564. batch_size, point_batch_size = input_points.shape[:2]
  565. if input_labels is None:
  566. raise ValueError("If points are provided, labels must also be provided.")
  567. point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
  568. sparse_embeddings = point_embeddings
  569. if input_boxes is not None:
  570. batch_size = input_boxes.shape[0]
  571. box_embeddings = self._embed_boxes(input_boxes)
  572. if sparse_embeddings is None:
  573. sparse_embeddings = box_embeddings
  574. else:
  575. sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=2)
  576. if input_masks is not None:
  577. dense_embeddings = self.mask_embed(input_masks)
  578. else:
  579. dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
  580. batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1]
  581. )
  582. if sparse_embeddings is None:
  583. sparse_embeddings = torch.zeros((batch_size, 1, 1, self.hidden_size), device=target_device)
  584. return sparse_embeddings, dense_embeddings
  585. class SamVisionAttention(nn.Module):
  586. """Multi-head Attention block with relative position embeddings."""
  587. def __init__(self, config, window_size):
  588. super().__init__()
  589. input_size = (
  590. (config.image_size // config.patch_size, config.image_size // config.patch_size)
  591. if window_size == 0
  592. else (window_size, window_size)
  593. )
  594. self.num_attention_heads = config.num_attention_heads
  595. head_dim = config.hidden_size // config.num_attention_heads
  596. self.scale = head_dim**-0.5
  597. self.dropout = config.attention_dropout
  598. self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.qkv_bias)
  599. self.proj = nn.Linear(config.hidden_size, config.hidden_size)
  600. self.use_rel_pos = config.use_rel_pos
  601. if self.use_rel_pos:
  602. if input_size is None:
  603. raise ValueError("Input size must be provided if using relative positional encoding.")
  604. # initialize relative positional embeddings
  605. self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
  606. self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
  607. def get_rel_pos(self, q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
  608. """
  609. Get relative positional embeddings according to the relative positions of
  610. query and key sizes.
  611. Args:
  612. q_size (int):
  613. size of the query.
  614. k_size (int):
  615. size of key k.
  616. rel_pos (`torch.Tensor`):
  617. relative position embeddings (L, channel).
  618. Returns:
  619. Extracted positional embeddings according to relative positions.
  620. """
  621. max_rel_dist = int(2 * max(q_size, k_size) - 1)
  622. # Interpolate rel pos.
  623. rel_pos_resized = F.interpolate(
  624. rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
  625. size=max_rel_dist,
  626. mode="linear",
  627. )
  628. rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
  629. # Scale the coords with short length if shapes for q and k are different.
  630. q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
  631. k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
  632. relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
  633. return rel_pos_resized[relative_coords.long()]
  634. def add_decomposed_rel_pos(
  635. self,
  636. attn: torch.Tensor,
  637. query: torch.Tensor,
  638. rel_pos_h: torch.Tensor,
  639. rel_pos_w: torch.Tensor,
  640. q_size: Tuple[int, int],
  641. k_size: Tuple[int, int],
  642. ) -> torch.Tensor:
  643. """
  644. Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
  645. https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
  646. Args:
  647. attn (`torch.Tensor`):
  648. attention map.
  649. query (`torch.Tensor`):
  650. query q in the attention layer with shape (batch_size, query_height * query_width, channel).
  651. rel_pos_h (`torch.Tensor`):
  652. relative position embeddings (Lh, channel) for height axis.
  653. rel_pos_w (`torch.Tensor`):
  654. relative position embeddings (Lw, channel) for width axis.
  655. q_size (tuple):
  656. spatial sequence size of query q with (query_height, query_width).
  657. k_size (tuple):
  658. spatial sequence size of key k with (key_height, key_width).
  659. Returns:
  660. attn (`torch.Tensor`):
  661. attention map with added relative positional embeddings.
  662. """
  663. query_height, query_width = q_size
  664. key_height, key_width = k_size
  665. relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
  666. relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)
  667. batch_size, _, dim = query.shape
  668. reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
  669. rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
  670. rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
  671. attn = attn.reshape(batch_size, query_height, query_width, key_height, key_width)
  672. attn = attn + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
  673. attn = attn.reshape(batch_size, query_height * query_width, key_height * key_width)
  674. return attn
  675. def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
  676. batch_size, height, width, _ = hidden_states.shape
  677. # qkv with shape (3, batch_size, nHead, height * width, channel)
  678. qkv = (
  679. self.qkv(hidden_states)
  680. .reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
  681. .permute(2, 0, 3, 1, 4)
  682. )
  683. # q, k, v with shape (batch_size * nHead, height * width, channel)
  684. query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(0)
  685. attn_weights = (query * self.scale) @ key.transpose(-2, -1)
  686. if self.use_rel_pos:
  687. attn_weights = self.add_decomposed_rel_pos(
  688. attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
  689. )
  690. attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
  691. attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
  692. attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
  693. attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
  694. attn_output = self.proj(attn_output)
  695. if output_attentions:
  696. outputs = (attn_output, attn_weights)
  697. else:
  698. outputs = (attn_output, None)
  699. return outputs
  700. class SamVisionLayer(nn.Module):
  701. def __init__(self, config, window_size):
  702. super().__init__()
  703. self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  704. self.attn = SamVisionAttention(config, window_size)
  705. self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  706. self.mlp = SamMLPBlock(config)
  707. self.window_size = window_size
  708. def window_partition(self, hidden_states: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
  709. """
  710. Args:
  711. Partition into non-overlapping windows with padding if needed.
  712. hidden_states (tensor): input tokens with [batch_size, height, width, channel]. window_size (int): window
  713. size.
  714. Returns:
  715. windows: windows after partition with [batch_size * num_windows, window_size, window_size, channel].
  716. (pad_height, pad_width): padded height and width before partition
  717. """
  718. batch_size, height, width, channel = hidden_states.shape
  719. pad_h = (window_size - height % window_size) % window_size
  720. pad_w = (window_size - width % window_size) % window_size
  721. hidden_states = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h))
  722. pad_height, pad_width = height + pad_h, width + pad_w
  723. hidden_states = hidden_states.reshape(
  724. batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel
  725. )
  726. windows = hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(-1, window_size, window_size, channel)
  727. return windows, (pad_height, pad_width)
  728. def window_unpartition(
  729. self, windows: torch.Tensor, window_size: int, padding_shape: Tuple[int, int], original_shape: Tuple[int, int]
  730. ) -> torch.Tensor:
  731. """
  732. Args:
  733. Window unpartition into original sequences and removing padding.
  734. hidden_states (tensor):
  735. input tokens with [batch_size * num_windows, window_size, window_size, channel].
  736. window_size (int):
  737. window size.
  738. padding_shape (Tuple):
  739. padded height and width (pad_height, pad_width).
  740. original_shape (Tuple): original height and width (height, width) before padding.
  741. Returns:
  742. hidden_states: unpartitioned sequences with [batch_size, height, width, channel].
  743. """
  744. pad_height, pad_width = padding_shape
  745. height, width = original_shape
  746. batch_size = windows.shape[0] // (pad_height * pad_width // window_size // window_size)
  747. hidden_states = windows.reshape(
  748. batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1
  749. )
  750. hidden_states = (
  751. hidden_states.permute(0, 1, 3, 2, 4, 5).contiguous().reshape(batch_size, pad_height, pad_width, -1)
  752. )
  753. hidden_states = hidden_states[:, :height, :width, :].contiguous()
  754. return hidden_states
  755. def forward(
  756. self,
  757. hidden_states: torch.Tensor,
  758. output_attentions: Optional[bool] = False,
  759. ) -> Tuple[torch.FloatTensor]:
  760. residual = hidden_states
  761. hidden_states = self.layer_norm1(hidden_states)
  762. # Window partition
  763. if self.window_size > 0:
  764. height, width = hidden_states.shape[1], hidden_states.shape[2]
  765. hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size)
  766. hidden_states, attn_weights = self.attn(
  767. hidden_states=hidden_states,
  768. output_attentions=output_attentions,
  769. )
  770. # Reverse window partition
  771. if self.window_size > 0:
  772. hidden_states = self.window_unpartition(hidden_states, self.window_size, padding_shape, (height, width))
  773. hidden_states = residual + hidden_states
  774. layernorm_output = self.layer_norm2(hidden_states)
  775. hidden_states = hidden_states + self.mlp(layernorm_output)
  776. outputs = (hidden_states,)
  777. if output_attentions:
  778. outputs += (attn_weights,)
  779. return outputs
  780. class SamVisionNeck(nn.Module):
  781. def __init__(self, config: SamVisionConfig):
  782. super().__init__()
  783. self.config = config
  784. self.conv1 = nn.Conv2d(config.hidden_size, config.output_channels, kernel_size=1, bias=False)
  785. self.layer_norm1 = SamLayerNorm(config.output_channels, data_format="channels_first")
  786. self.conv2 = nn.Conv2d(config.output_channels, config.output_channels, kernel_size=3, padding=1, bias=False)
  787. self.layer_norm2 = SamLayerNorm(config.output_channels, data_format="channels_first")
  788. def forward(self, hidden_states):
  789. hidden_states = hidden_states.permute(0, 3, 1, 2)
  790. hidden_states = self.conv1(hidden_states)
  791. hidden_states = self.layer_norm1(hidden_states)
  792. hidden_states = self.conv2(hidden_states)
  793. hidden_states = self.layer_norm2(hidden_states)
  794. return hidden_states
  795. class SamVisionEncoder(nn.Module):
  796. def __init__(self, config: SamVisionConfig):
  797. super().__init__()
  798. self.config = config
  799. self.image_size = config.image_size
  800. self.patch_embed = SamPatchEmbeddings(config)
  801. self.pos_embed = None
  802. if config.use_abs_pos:
  803. # Initialize absolute positional embedding with pretrain image size.
  804. self.pos_embed = nn.Parameter(
  805. torch.zeros(
  806. 1,
  807. config.image_size // config.patch_size,
  808. config.image_size // config.patch_size,
  809. config.hidden_size,
  810. )
  811. )
  812. self.layers = nn.ModuleList()
  813. for i in range(config.num_hidden_layers):
  814. layer = SamVisionLayer(
  815. config,
  816. window_size=config.window_size if i not in config.global_attn_indexes else 0,
  817. )
  818. self.layers.append(layer)
  819. self.neck = SamVisionNeck(config)
  820. self.gradient_checkpointing = False
  821. def get_input_embeddings(self):
  822. return self.patch_embed
  823. def forward(
  824. self,
  825. pixel_values: Optional[torch.FloatTensor] = None,
  826. output_attentions: Optional[bool] = None,
  827. output_hidden_states: Optional[bool] = None,
  828. return_dict: Optional[bool] = None,
  829. ) -> Union[Tuple, SamVisionEncoderOutput]:
  830. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  831. output_hidden_states = (
  832. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  833. )
  834. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  835. if pixel_values is None:
  836. raise ValueError("You have to specify pixel_values")
  837. hidden_states = self.patch_embed(pixel_values)
  838. if self.pos_embed is not None:
  839. hidden_states = hidden_states + self.pos_embed
  840. all_hidden_states = () if output_hidden_states else None
  841. all_self_attentions = () if output_attentions else None
  842. for i, layer_module in enumerate(self.layers):
  843. if output_hidden_states:
  844. all_hidden_states = all_hidden_states + (hidden_states,)
  845. if self.gradient_checkpointing and self.training:
  846. layer_outputs = self._gradient_checkpointing_func(
  847. layer_module.__call__,
  848. hidden_states,
  849. )
  850. else:
  851. layer_outputs = layer_module(hidden_states, output_attentions=output_attentions)
  852. hidden_states = layer_outputs[0]
  853. if output_attentions:
  854. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  855. if output_hidden_states:
  856. all_hidden_states = all_hidden_states + (hidden_states,)
  857. hidden_states = self.neck(hidden_states)
  858. if not return_dict:
  859. outputs = (hidden_states,)
  860. if output_hidden_states:
  861. outputs = outputs + (all_hidden_states,)
  862. if output_attentions:
  863. outputs = outputs + (all_self_attentions,)
  864. return outputs
  865. return SamVisionEncoderOutput(
  866. last_hidden_state=hidden_states,
  867. hidden_states=all_hidden_states,
  868. attentions=all_self_attentions,
  869. )
  870. class SamPreTrainedModel(PreTrainedModel):
  871. config_class = SamConfig
  872. base_model_prefix = "sam"
  873. main_input_name = "pixel_values"
  874. _no_split_modules = ["SamVisionAttention"]
  875. def _init_weights(self, module):
  876. std = self.config.initializer_range
  877. if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
  878. module.weight.data.normal_(mean=0.0, std=std)
  879. if module.bias is not None:
  880. module.bias.data.zero_()
  881. elif isinstance(module, nn.Embedding):
  882. module.weight.data.normal_(mean=0.0, std=std)
  883. if module.padding_idx is not None:
  884. module.weight.data[module.padding_idx].zero_()
  885. SAM_START_DOCSTRING = r"""
  886. This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
  887. library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
  888. etc.)
  889. This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
  890. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
  891. and behavior.
  892. Parameters:
  893. config ([`SamConfig`]): Model configuration class with all the parameters of the model.
  894. Initializing with a config file does not load the weights associated with the model, only the
  895. configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
  896. """
  897. SAM_INPUTS_DOCSTRING = r"""
  898. Args:
  899. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
  900. Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for
  901. details.
  902. input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`):
  903. Input 2D spatial points, this is used by the prompt encoder to encode the prompt. Generally yields to much
  904. better results. The points can be obtained by passing a list of list of list to the processor that will
  905. create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the
  906. second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict
  907. per input point), the third dimension is the number of points per segmentation mask (it is possible to pass
  908. multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal)
  909. coordinates of the point. If a different number of points is passed either for each image, or for each
  910. mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
  911. computation of the embedding will be skipped for these points using the labels.
  912. input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`):
  913. Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the
  914. official implementation, there are 3 types of labels
  915. - `1`: the point is a point that contains the object of interest
  916. - `0`: the point is a point that does not contain the object of interest
  917. - `-1`: the point corresponds to the background
  918. We added the label:
  919. - `-10`: the point is a padding point, thus should be ignored by the prompt encoder
  920. The padding labels should be automatically done by the processor.
  921. input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`):
  922. Input boxes for the points, this is used by the prompt encoder to encode the prompt. Generally yields to
  923. much better generated masks. The boxes can be obtained by passing a list of list of list to the processor,
  924. that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch
  925. size, the number of boxes per image and the coordinates of the top left and botton right point of the box.
  926. In the order (`x1`, `y1`, `x2`, `y2`):
  927. - `x1`: the x coordinate of the top left point of the input box
  928. - `y1`: the y coordinate of the top left point of the input box
  929. - `x2`: the x coordinate of the bottom right point of the input box
  930. - `y2`: the y coordinate of the bottom right point of the input box
  931. input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`):
  932. SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
  933. generate a corresponding embedding, that will be fed later on to the mask decoder. These masks needs to be
  934. manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`).
  935. image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`):
  936. Image embeddings, this is used by the mask decder to generate masks and iou scores. For more memory
  937. efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings`
  938. method, and then feed them to the `forward` method instead of feeding the `pixel_values`.
  939. multimask_output (`bool`, *optional*):
  940. In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
  941. bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
  942. "best" mask, by specifying `multimask_output=False`.
  943. attention_similarity (`torch.FloatTensor`, *optional*):
  944. Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
  945. model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048).
  946. target_embedding (`torch.FloatTensor`, *optional*):
  947. Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
  948. the model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048).
  949. output_attentions (`bool`, *optional*):
  950. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  951. tensors for more detail.
  952. output_hidden_states (`bool`, *optional*):
  953. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  954. more detail.
  955. return_dict (`bool`, *optional*):
  956. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  957. """
  958. @add_start_docstrings(
  959. "Segment Anything Model (SAM) for generating segmentation masks, given an input image and ",
  960. " optional 2D location and bounding boxes.",
  961. SAM_START_DOCSTRING,
  962. )
  963. class SamModel(SamPreTrainedModel):
  964. _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"]
  965. def __init__(self, config):
  966. super().__init__(config)
  967. self.shared_image_embedding = SamPositionalEmbedding(config.vision_config)
  968. self.vision_encoder = SamVisionEncoder(config.vision_config)
  969. self.prompt_encoder = SamPromptEncoder(config.prompt_encoder_config, self.shared_image_embedding)
  970. self.mask_decoder = SamMaskDecoder(config.mask_decoder_config)
  971. self.post_init()
  972. def get_input_embeddings(self):
  973. return self.vision_encoder.get_input_embeddings()
  974. def get_image_wide_positional_embeddings(self):
  975. size = self.config.prompt_encoder_config.image_embedding_size
  976. target_device = self.shared_image_embedding.positional_embedding.device
  977. target_dtype = self.shared_image_embedding.positional_embedding.dtype
  978. grid = torch.ones((size, size), device=target_device, dtype=target_dtype)
  979. y_embed = grid.cumsum(dim=0) - 0.5
  980. x_embed = grid.cumsum(dim=1) - 0.5
  981. y_embed = y_embed / size
  982. x_embed = x_embed / size
  983. positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1))
  984. return positional_embedding.permute(2, 0, 1).unsqueeze(0) # channel x height x width
  985. @torch.no_grad()
  986. def get_image_embeddings(
  987. self,
  988. pixel_values,
  989. output_attentions: Optional[bool] = None,
  990. output_hidden_states: Optional[bool] = None,
  991. return_dict: Optional[bool] = None,
  992. ):
  993. r"""
  994. Returns the image embeddings by passing the pixel values through the vision encoder.
  995. Args:
  996. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
  997. Input pixel values
  998. output_attentions (`bool`, *optional*):
  999. Whether or not to return the attentions tensors of all attention layers.
  1000. output_hidden_states (`bool`, *optional*):
  1001. Whether or not to return the hidden states of all layers.
  1002. return_dict (`bool`, *optional*):
  1003. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  1004. """
  1005. vision_output = self.vision_encoder(
  1006. pixel_values,
  1007. output_attentions=output_attentions,
  1008. output_hidden_states=output_hidden_states,
  1009. return_dict=return_dict,
  1010. )
  1011. image_embeddings = vision_output[0]
  1012. return image_embeddings
  1013. @torch.no_grad()
  1014. def get_prompt_embeddings(
  1015. self,
  1016. input_points: Optional[torch.FloatTensor] = None,
  1017. input_labels: Optional[torch.LongTensor] = None,
  1018. input_boxes: Optional[torch.FloatTensor] = None,
  1019. input_masks: Optional[torch.LongTensor] = None,
  1020. ):
  1021. r"""
  1022. Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.
  1023. Args:
  1024. input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
  1025. Optional input points for the prompt encoder. The padding of the point is automatically done by the
  1026. processor. `point_batch_size` refers to the number of masks that we want the model to predict per
  1027. point. The model will output `point_batch_size` times 3 masks in total.
  1028. input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
  1029. Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
  1030. processor, or can be fed by the user.
  1031. input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`):
  1032. Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
  1033. processor. users can also pass manually the input boxes.
  1034. input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`):
  1035. Optional input masks for the prompt encoder.
  1036. """
  1037. prompt_output = self.prompt_encoder(
  1038. input_points=input_points,
  1039. input_labels=input_labels,
  1040. input_boxes=input_boxes,
  1041. input_masks=input_masks,
  1042. )
  1043. return prompt_output
  1044. @add_start_docstrings_to_model_forward(SAM_INPUTS_DOCSTRING)
  1045. def forward(
  1046. self,
  1047. pixel_values: Optional[torch.FloatTensor] = None,
  1048. input_points: Optional[torch.FloatTensor] = None,
  1049. input_labels: Optional[torch.LongTensor] = None,
  1050. input_boxes: Optional[torch.FloatTensor] = None,
  1051. input_masks: Optional[torch.LongTensor] = None,
  1052. image_embeddings: Optional[torch.FloatTensor] = None,
  1053. multimask_output: bool = True,
  1054. attention_similarity: Optional[torch.FloatTensor] = None,
  1055. target_embedding: Optional[torch.FloatTensor] = None,
  1056. output_attentions: Optional[bool] = None,
  1057. output_hidden_states: Optional[bool] = None,
  1058. return_dict: Optional[bool] = None,
  1059. **kwargs,
  1060. ) -> List[Dict[str, torch.Tensor]]:
  1061. r"""
  1062. Example:
  1063. ```python
  1064. >>> from PIL import Image
  1065. >>> import requests
  1066. >>> from transformers import AutoModel, AutoProcessor
  1067. >>> model = AutoModel.from_pretrained("facebook/sam-vit-base")
  1068. >>> processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")
  1069. >>> img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
  1070. >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
  1071. >>> input_points = [[[400, 650]]] # 2D location of a window on the car
  1072. >>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")
  1073. >>> # Get segmentation mask
  1074. >>> outputs = model(**inputs)
  1075. >>> # Postprocess masks
  1076. >>> masks = processor.post_process_masks(
  1077. ... outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
  1078. ... )
  1079. ```
  1080. """
  1081. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1082. output_hidden_states = (
  1083. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1084. )
  1085. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1086. if pixel_values is None and image_embeddings is None:
  1087. raise ValueError("Either pixel_values or image_embeddings must be provided.")
  1088. if pixel_values is not None and image_embeddings is not None:
  1089. raise ValueError("Only one of pixel_values and image_embeddings can be provided.")
  1090. if input_points is not None and len(input_points.shape) != 4:
  1091. raise ValueError(
  1092. "The input_points must be a 4D tensor. Of shape `batch_size`, `point_batch_size`, `nb_points_per_image`, `2`.",
  1093. " got {}.".format(input_points.shape),
  1094. )
  1095. if input_boxes is not None and len(input_boxes.shape) != 3:
  1096. raise ValueError(
  1097. "The input_points must be a 3D tensor. Of shape `batch_size`, `nb_boxes`, `4`.",
  1098. " got {}.".format(input_boxes.shape),
  1099. )
  1100. if input_points is not None and input_boxes is not None:
  1101. point_batch_size = input_points.shape[1]
  1102. box_batch_size = input_boxes.shape[1]
  1103. if point_batch_size != box_batch_size:
  1104. raise ValueError(
  1105. "You should provide as many bounding boxes as input points per box. Got {} and {}.".format(
  1106. point_batch_size, box_batch_size
  1107. )
  1108. )
  1109. image_positional_embeddings = self.get_image_wide_positional_embeddings()
  1110. # repeat with batch size
  1111. batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings.shape[0]
  1112. image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)
  1113. vision_attentions = None
  1114. vision_hidden_states = None
  1115. if pixel_values is not None:
  1116. vision_outputs = self.vision_encoder(
  1117. pixel_values,
  1118. output_attentions=output_attentions,
  1119. output_hidden_states=output_hidden_states,
  1120. return_dict=return_dict,
  1121. )
  1122. image_embeddings = vision_outputs[0]
  1123. if output_hidden_states:
  1124. vision_hidden_states = vision_outputs[1]
  1125. if output_attentions:
  1126. vision_attentions = vision_outputs[-1]
  1127. if input_points is not None and input_labels is None:
  1128. input_labels = torch.ones_like(input_points[:, :, :, 0], dtype=torch.int, device=input_points.device)
  1129. if input_points is not None and image_embeddings.shape[0] != input_points.shape[0]:
  1130. raise ValueError(
  1131. "The batch size of the image embeddings and the input points must be the same. ",
  1132. "Got {} and {} respectively.".format(image_embeddings.shape[0], input_points.shape[0]),
  1133. " if you want to pass multiple points for the same image, make sure that you passed ",
  1134. " input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and ",
  1135. " input_labels of shape (batch_size, point_batch_size, num_points_per_image)",
  1136. )
  1137. sparse_embeddings, dense_embeddings = self.prompt_encoder(
  1138. input_points=input_points,
  1139. input_labels=input_labels,
  1140. input_boxes=input_boxes,
  1141. input_masks=input_masks,
  1142. )
  1143. low_res_masks, iou_predictions, mask_decoder_attentions = self.mask_decoder(
  1144. image_embeddings=image_embeddings,
  1145. image_positional_embeddings=image_positional_embeddings,
  1146. sparse_prompt_embeddings=sparse_embeddings,
  1147. dense_prompt_embeddings=dense_embeddings,
  1148. multimask_output=multimask_output,
  1149. attention_similarity=attention_similarity,
  1150. target_embedding=target_embedding,
  1151. output_attentions=output_attentions,
  1152. )
  1153. if not return_dict:
  1154. output = (iou_predictions, low_res_masks)
  1155. if output_hidden_states:
  1156. output = output + (vision_hidden_states,)
  1157. if output_attentions:
  1158. output = output + (vision_attentions, mask_decoder_attentions)
  1159. return output
  1160. return SamImageSegmentationOutput(
  1161. iou_scores=iou_predictions,
  1162. pred_masks=low_res_masks,
  1163. vision_hidden_states=vision_hidden_states,
  1164. vision_attentions=vision_attentions,
  1165. mask_decoder_attentions=mask_decoder_attentions,
  1166. )