modeling_clipseg.py

  1. # coding=utf-8
  2. # Copyright 2022 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch CLIPSeg model."""
  16. import copy
  17. import math
  18. from dataclasses import dataclass
  19. from typing import Any, Optional, Tuple, Union
  20. import torch
  21. import torch.utils.checkpoint
  22. from torch import nn
  23. from ...activations import ACT2FN
  24. from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
  25. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
  26. from ...modeling_utils import PreTrainedModel
  27. from ...utils import (
  28. ModelOutput,
  29. add_start_docstrings,
  30. add_start_docstrings_to_model_forward,
  31. logging,
  32. replace_return_docstrings,
  33. torch_int,
  34. )
  35. from .configuration_clipseg import CLIPSegConfig, CLIPSegTextConfig, CLIPSegVisionConfig
  36. logger = logging.get_logger(__name__)
  37. _CHECKPOINT_FOR_DOC = "CIDAS/clipseg-rd64-refined"
  38. # contrastive loss function, adapted from
  39. # https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
  40. def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
  41. return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
  42. # Copied from transformers.models.clip.modeling_clip.clip_loss with clip->clipseg
  43. def clipseg_loss(similarity: torch.Tensor) -> torch.Tensor:
  44. caption_loss = contrastive_loss(similarity)
  45. image_loss = contrastive_loss(similarity.t())
  46. return (caption_loss + image_loss) / 2.0
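# Illustrative sketch (not part of the original file): the symmetric CLIP-style loss above is
# the average of a text->image and an image->text cross-entropy, each using the diagonal of the
# similarity matrix as the target class. The helper name `_example_symmetric_loss` is hypothetical.
def _example_symmetric_loss() -> torch.Tensor:
    # 2 texts x 2 images; large diagonal entries mean matched pairs score highest
    similarity = torch.tensor([[4.0, 0.1], [0.2, 3.5]])
    return clipseg_loss(similarity)  # small value, since the diagonal dominates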
  47. @dataclass
  48. # Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->CLIPSeg
  49. class CLIPSegOutput(ModelOutput):
  50. """
  51. Args:
  52. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
  53. Contrastive loss for image-text similarity.
  54. logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
  55. The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
  56. similarity scores.
  57. logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
  58. The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
  59. similarity scores.
  60. text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
  61. The text embeddings obtained by applying the projection layer to the pooled output of [`CLIPSegTextModel`].
  62. image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
  63. The image embeddings obtained by applying the projection layer to the pooled output of [`CLIPSegVisionModel`].
  64. text_model_output (`BaseModelOutputWithPooling`):
  65. The output of the [`CLIPSegTextModel`].
  66. vision_model_output (`BaseModelOutputWithPooling`):
  67. The output of the [`CLIPSegVisionModel`].
  68. """
  69. loss: Optional[torch.FloatTensor] = None
  70. logits_per_image: torch.FloatTensor = None
  71. logits_per_text: torch.FloatTensor = None
  72. text_embeds: torch.FloatTensor = None
  73. image_embeds: torch.FloatTensor = None
  74. text_model_output: BaseModelOutputWithPooling = None
  75. vision_model_output: BaseModelOutputWithPooling = None
  76. def to_tuple(self) -> Tuple[Any]:
  77. return tuple(
  78. self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
  79. for k in self.keys()
  80. )
  81. @dataclass
  82. class CLIPSegDecoderOutput(ModelOutput):
  83. """
  84. Args:
  85. logits (`torch.FloatTensor` of shape `(batch_size, height, width)`):
  86. Classification scores for each pixel.
  87. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  88. Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
  89. one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
  90. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  91. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  92. sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
  93. the self-attention heads.
  94. """
  95. logits: torch.FloatTensor = None
  96. hidden_states: Optional[Tuple[torch.FloatTensor]] = None
  97. attentions: Optional[Tuple[torch.FloatTensor]] = None
  98. @dataclass
  99. class CLIPSegImageSegmentationOutput(ModelOutput):
  100. """
  101. Args:
  102. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
  103. Contrastive loss for image-text similarity.
  104. ...
  105. vision_model_output (`BaseModelOutputWithPooling`):
  106. The output of the [`CLIPSegVisionModel`].
  107. """
  108. loss: Optional[torch.FloatTensor] = None
  109. logits: torch.FloatTensor = None
  110. conditional_embeddings: torch.FloatTensor = None
  111. pooled_output: torch.FloatTensor = None
  112. vision_model_output: BaseModelOutputWithPooling = None
  113. decoder_output: CLIPSegDecoderOutput = None
  114. def to_tuple(self) -> Tuple[Any]:
  115. return tuple(
  116. self[k] if k not in ["vision_model_output", "decoder_output"] else getattr(self, k).to_tuple()
  117. for k in self.keys()
  118. )
  119. class CLIPSegVisionEmbeddings(nn.Module):
  120. # Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings.__init__ with CLIP->CLIPSeg
  121. def __init__(self, config: CLIPSegVisionConfig):
  122. super().__init__()
  123. self.config = config
  124. self.embed_dim = config.hidden_size
  125. self.image_size = config.image_size
  126. self.patch_size = config.patch_size
  127. self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
  128. self.patch_embedding = nn.Conv2d(
  129. in_channels=config.num_channels,
  130. out_channels=self.embed_dim,
  131. kernel_size=self.patch_size,
  132. stride=self.patch_size,
  133. bias=False,
  134. )
  135. self.num_patches = (self.image_size // self.patch_size) ** 2
  136. self.num_positions = self.num_patches + 1
  137. self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
  138. self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
  139. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  140. """
  141. This method interpolates the pre-trained position encodings so that the model can be used on higher-resolution
  142. images. It is also adapted to support torch.jit tracing.
  143. Adapted from:
  144. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  145. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  146. """
  147. num_patches = embeddings.shape[1] - 1
  148. position_embedding = self.position_embedding.weight.unsqueeze(0)
  149. num_positions = position_embedding.shape[1] - 1
  150. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  151. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  152. return self.position_embedding(self.position_ids)
  153. class_pos_embed = position_embedding[:, :1]
  154. patch_pos_embed = position_embedding[:, 1:]
  155. dim = embeddings.shape[-1]
  156. new_height = height // self.patch_size
  157. new_width = width // self.patch_size
  158. sqrt_num_positions = torch_int(num_positions**0.5)
  159. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  160. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  161. patch_pos_embed = nn.functional.interpolate(
  162. patch_pos_embed,
  163. size=(new_height, new_width),
  164. mode="bicubic",
  165. align_corners=False,
  166. )
  167. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  168. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  169. def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
  170. batch_size, _, height, width = pixel_values.shape
  171. if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
  172. raise ValueError(
  173. f"Input image size ({height}*{width}) doesn't match model" f" ({self.image_size}*{self.image_size})."
  174. )
  175. patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
  176. patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
  177. class_embeds = self.class_embedding.expand(batch_size, 1, -1)
  178. embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
  179. if interpolate_pos_encoding:
  180. embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
  181. else:
  182. embeddings = embeddings + self.position_embedding(self.position_ids)
  183. return embeddings
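# Illustrative sketch (not part of the original file): how `interpolate_pos_encoding` above resizes
# the pre-trained patch-position grid. For example, a 224x224 training resolution with patch_size=16
# gives a 14x14 grid (196 positions); a 352x352 input needs a 22x22 grid (484 positions), obtained by
# bicubic interpolation. The helper name `_example_interpolate_grid` is hypothetical.
def _example_interpolate_grid(patch_pos_embed: torch.Tensor, new_height: int, new_width: int) -> torch.Tensor:
    # patch_pos_embed: (1, old_side * old_side, dim), laid out as a square grid
    dim = patch_pos_embed.shape[-1]
    old_side = int(patch_pos_embed.shape[1] ** 0.5)
    grid = patch_pos_embed.reshape(1, old_side, old_side, dim).permute(0, 3, 1, 2)
    grid = nn.functional.interpolate(grid, size=(new_height, new_width), mode="bicubic", align_corners=False)
    return grid.permute(0, 2, 3, 1).reshape(1, new_height * new_width, dim)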
  184. # Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->CLIPSeg
  185. class CLIPSegTextEmbeddings(nn.Module):
  186. def __init__(self, config: CLIPSegTextConfig):
  187. super().__init__()
  188. embed_dim = config.hidden_size
  189. self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
  190. self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
  191. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  192. self.register_buffer(
  193. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  194. )
  195. def forward(
  196. self,
  197. input_ids: Optional[torch.LongTensor] = None,
  198. position_ids: Optional[torch.LongTensor] = None,
  199. inputs_embeds: Optional[torch.FloatTensor] = None,
  200. ) -> torch.Tensor:
  201. seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
  202. if position_ids is None:
  203. position_ids = self.position_ids[:, :seq_length]
  204. if inputs_embeds is None:
  205. inputs_embeds = self.token_embedding(input_ids)
  206. position_embeddings = self.position_embedding(position_ids)
  207. embeddings = inputs_embeds + position_embeddings
  208. return embeddings
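# Illustrative sketch (not part of the original file): the text embedding above is simply the sum of a
# token embedding and a learned absolute position embedding, with the position table broadcast over the
# batch. The helper name `_example_text_embedding_sum` is hypothetical.
def _example_text_embedding_sum(token_emb: nn.Embedding, pos_emb: nn.Embedding, input_ids: torch.Tensor) -> torch.Tensor:
    # position_ids of shape (1, seq_len) broadcasts against (batch_size, seq_len, hidden_size) below
    position_ids = torch.arange(input_ids.shape[-1], device=input_ids.device).unsqueeze(0)
    return token_emb(input_ids) + pos_emb(position_ids)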
  209. # Copied from transformers.models.clip.modeling_clip.CLIPAttention with CLIP->CLIPSeg
  210. class CLIPSegAttention(nn.Module):
  211. """Multi-headed attention from 'Attention Is All You Need' paper"""
  212. def __init__(self, config):
  213. super().__init__()
  214. self.config = config
  215. self.embed_dim = config.hidden_size
  216. self.num_heads = config.num_attention_heads
  217. self.head_dim = self.embed_dim // self.num_heads
  218. if self.head_dim * self.num_heads != self.embed_dim:
  219. raise ValueError(
  220. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  221. f" {self.num_heads})."
  222. )
  223. self.scale = self.head_dim**-0.5
  224. self.dropout = config.attention_dropout
  225. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
  226. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
  227. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
  228. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
  229. def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
  230. return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
  231. def forward(
  232. self,
  233. hidden_states: torch.Tensor,
  234. attention_mask: Optional[torch.Tensor] = None,
  235. causal_attention_mask: Optional[torch.Tensor] = None,
  236. output_attentions: Optional[bool] = False,
  237. ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
  238. """Input shape: Batch x Time x Channel"""
  239. bsz, tgt_len, embed_dim = hidden_states.size()
  240. # get query proj
  241. query_states = self.q_proj(hidden_states) * self.scale
  242. key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
  243. value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
  244. proj_shape = (bsz * self.num_heads, -1, self.head_dim)
  245. query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
  246. key_states = key_states.view(*proj_shape)
  247. value_states = value_states.view(*proj_shape)
  248. src_len = key_states.size(1)
  249. attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
  250. if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
  251. raise ValueError(
  252. f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
  253. f" {attn_weights.size()}"
  254. )
  255. # apply the causal_attention_mask first
  256. if causal_attention_mask is not None:
  257. if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
  258. raise ValueError(
  259. f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
  260. f" {causal_attention_mask.size()}"
  261. )
  262. attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
  263. attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
  264. if attention_mask is not None:
  265. if attention_mask.size() != (bsz, 1, tgt_len, src_len):
  266. raise ValueError(
  267. f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
  268. )
  269. attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
  270. attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
  271. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  272. if output_attentions:
  273. # this operation is a bit awkward, but it's required to
  274. # make sure that attn_weights keeps its gradient.
  275. # In order to do so, attn_weights have to be reshaped
  276. # twice and have to be reused in the following
  277. attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
  278. attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
  279. else:
  280. attn_weights_reshaped = None
  281. attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
  282. attn_output = torch.bmm(attn_probs, value_states)
  283. if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
  284. raise ValueError(
  285. f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
  286. f" {attn_output.size()}"
  287. )
  288. attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
  289. attn_output = attn_output.transpose(1, 2)
  290. attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
  291. attn_output = self.out_proj(attn_output)
  292. return attn_output, attn_weights_reshaped
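# Illustrative sketch (not part of the original file): the per-head computation above is equivalent
# to a plain scaled dot-product attention, with the scale applied to the query projection. This
# hypothetical helper shows the shapes after `_shape` + view into (batch * num_heads, seq, head_dim).
def _example_scaled_dot_product(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale: float) -> torch.Tensor:
    attn_weights = torch.bmm(q * scale, k.transpose(1, 2))    # (B*H, tgt_len, src_len)
    attn_probs = nn.functional.softmax(attn_weights, dim=-1)  # each row sums to 1
    return torch.bmm(attn_probs, v)                           # (B*H, tgt_len, head_dim)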
  293. # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->CLIPSeg
  294. class CLIPSegMLP(nn.Module):
  295. def __init__(self, config):
  296. super().__init__()
  297. self.config = config
  298. self.activation_fn = ACT2FN[config.hidden_act]
  299. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  300. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  301. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  302. hidden_states = self.fc1(hidden_states)
  303. hidden_states = self.activation_fn(hidden_states)
  304. hidden_states = self.fc2(hidden_states)
  305. return hidden_states
  306. # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->CLIPSeg
  307. class CLIPSegEncoderLayer(nn.Module):
  308. def __init__(self, config: CLIPSegConfig):
  309. super().__init__()
  310. self.embed_dim = config.hidden_size
  311. self.self_attn = CLIPSegAttention(config)
  312. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  313. self.mlp = CLIPSegMLP(config)
  314. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  315. def forward(
  316. self,
  317. hidden_states: torch.Tensor,
  318. attention_mask: torch.Tensor,
  319. causal_attention_mask: torch.Tensor,
  320. output_attentions: Optional[bool] = False,
  321. ) -> Tuple[torch.FloatTensor]:
  322. """
  323. Args:
  324. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  325. attention_mask (`torch.FloatTensor`): attention mask of size
  326. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  327. causal_attention_mask (`torch.FloatTensor`): causal attention mask of size `(batch, 1, tgt_len, src_len)`.
  328. output_attentions (`bool`, *optional*):
  329. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  330. returned tensors for more detail.
  331. """
  332. residual = hidden_states
  333. hidden_states = self.layer_norm1(hidden_states)
  334. hidden_states, attn_weights = self.self_attn(
  335. hidden_states=hidden_states,
  336. attention_mask=attention_mask,
  337. causal_attention_mask=causal_attention_mask,
  338. output_attentions=output_attentions,
  339. )
  340. hidden_states = residual + hidden_states
  341. residual = hidden_states
  342. hidden_states = self.layer_norm2(hidden_states)
  343. hidden_states = self.mlp(hidden_states)
  344. hidden_states = residual + hidden_states
  345. outputs = (hidden_states,)
  346. if output_attentions:
  347. outputs += (attn_weights,)
  348. return outputs
  349. class CLIPSegPreTrainedModel(PreTrainedModel):
  350. """
  351. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  352. models.
  353. """
  354. config_class = CLIPSegConfig
  355. base_model_prefix = "clip"
  356. supports_gradient_checkpointing = True
  357. def _init_weights(self, module):
  358. """Initialize the weights"""
  359. factor = self.config.initializer_factor
  360. if isinstance(module, CLIPSegTextEmbeddings):
  361. module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
  362. module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02)
  363. elif isinstance(module, CLIPSegVisionEmbeddings):
  364. factor = self.config.initializer_factor
  365. nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
  366. nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
  367. nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
  368. elif isinstance(module, CLIPSegAttention):
  369. factor = self.config.initializer_factor
  370. in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
  371. out_proj_std = (module.embed_dim**-0.5) * factor
  372. nn.init.normal_(module.q_proj.weight, std=in_proj_std)
  373. nn.init.normal_(module.k_proj.weight, std=in_proj_std)
  374. nn.init.normal_(module.v_proj.weight, std=in_proj_std)
  375. nn.init.normal_(module.out_proj.weight, std=out_proj_std)
  376. elif isinstance(module, CLIPSegMLP):
  377. factor = self.config.initializer_factor
  378. in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
  379. fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
  380. nn.init.normal_(module.fc1.weight, std=fc_std)
  381. nn.init.normal_(module.fc2.weight, std=in_proj_std)
  382. elif isinstance(module, CLIPSegModel):
  383. nn.init.normal_(
  384. module.text_projection.weight,
  385. std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
  386. )
  387. nn.init.normal_(
  388. module.visual_projection.weight,
  389. std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
  390. )
  391. if isinstance(module, nn.LayerNorm):
  392. module.bias.data.zero_()
  393. module.weight.data.fill_(1.0)
  394. if isinstance(module, nn.Linear) and module.bias is not None:
  395. module.bias.data.zero_()
  396. CLIPSEG_START_DOCSTRING = r"""
  397. This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
  398. as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
  399. behavior.
  400. Parameters:
  401. config ([`CLIPSegConfig`]): Model configuration class with all the parameters of the model.
  402. Initializing with a config file does not load the weights associated with the model, only the
  403. configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
  404. """
  405. CLIPSEG_TEXT_INPUTS_DOCSTRING = r"""
  406. Args:
  407. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  408. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
  409. it.
  410. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  411. [`PreTrainedTokenizer.__call__`] for details.
  412. [What are input IDs?](../glossary#input-ids)
  413. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  414. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  415. - 1 for tokens that are **not masked**,
  416. - 0 for tokens that are **masked**.
  417. [What are attention masks?](../glossary#attention-mask)
  418. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  419. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  420. config.max_position_embeddings - 1]`.
  421. [What are position IDs?](../glossary#position-ids)
  422. output_attentions (`bool`, *optional*):
  423. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  424. tensors for more detail.
  425. output_hidden_states (`bool`, *optional*):
  426. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  427. more detail.
  428. return_dict (`bool`, *optional*):
  429. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  430. """
  431. CLIPSEG_VISION_INPUTS_DOCSTRING = r"""
  432. Args:
  433. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
  434. Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
  435. [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
  436. output_attentions (`bool`, *optional*):
  437. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  438. tensors for more detail.
  439. output_hidden_states (`bool`, *optional*):
  440. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  441. more detail.
  442. interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
  443. Whether to interpolate the pre-trained position encodings.
  444. return_dict (`bool`, *optional*):
  445. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  446. """
  447. CLIPSEG_INPUTS_DOCSTRING = r"""
  448. Args:
  449. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  450. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
  451. it.
  452. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  453. [`PreTrainedTokenizer.__call__`] for details.
  454. [What are input IDs?](../glossary#input-ids)
  455. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  456. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  457. - 1 for tokens that are **not masked**,
  458. - 0 for tokens that are **masked**.
  459. [What are attention masks?](../glossary#attention-mask)
  460. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  461. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  462. config.max_position_embeddings - 1]`.
  463. [What are position IDs?](../glossary#position-ids)
  464. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
  465. Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
  466. [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
  467. return_loss (`bool`, *optional*):
  468. Whether or not to return the contrastive loss.
  469. output_attentions (`bool`, *optional*):
  470. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  471. tensors for more detail.
  472. output_hidden_states (`bool`, *optional*):
  473. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  474. more detail.
  475. interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
  476. Whether to interpolate the pre-trained position encodings.
  477. return_dict (`bool`, *optional*):
  478. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  479. """
  480. # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->CLIPSeg
  481. class CLIPSegEncoder(nn.Module):
  482. """
  483. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  484. [`CLIPSegEncoderLayer`].
  485. Args:
  486. config: CLIPSegConfig
  487. """
  488. def __init__(self, config: CLIPSegConfig):
  489. super().__init__()
  490. self.config = config
  491. self.layers = nn.ModuleList([CLIPSegEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  492. self.gradient_checkpointing = False
  493. def forward(
  494. self,
  495. inputs_embeds,
  496. attention_mask: Optional[torch.Tensor] = None,
  497. causal_attention_mask: Optional[torch.Tensor] = None,
  498. output_attentions: Optional[bool] = None,
  499. output_hidden_states: Optional[bool] = None,
  500. return_dict: Optional[bool] = None,
  501. ) -> Union[Tuple, BaseModelOutput]:
  502. r"""
  503. Args:
  504. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  505. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  506. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  507. than the model's internal embedding lookup matrix.
  508. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  509. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  510. - 1 for tokens that are **not masked**,
  511. - 0 for tokens that are **masked**.
  512. [What are attention masks?](../glossary#attention-mask)
  513. causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  514. Causal mask for the text model. Mask values selected in `[0, 1]`:
  515. - 1 for tokens that are **not masked**,
  516. - 0 for tokens that are **masked**.
  517. [What are attention masks?](../glossary#attention-mask)
  518. output_attentions (`bool`, *optional*):
  519. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  520. returned tensors for more detail.
  521. output_hidden_states (`bool`, *optional*):
  522. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  523. for more detail.
  524. return_dict (`bool`, *optional*):
  525. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  526. """
  527. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  528. output_hidden_states = (
  529. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  530. )
  531. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  532. encoder_states = () if output_hidden_states else None
  533. all_attentions = () if output_attentions else None
  534. hidden_states = inputs_embeds
  535. for idx, encoder_layer in enumerate(self.layers):
  536. if output_hidden_states:
  537. encoder_states = encoder_states + (hidden_states,)
  538. if self.gradient_checkpointing and self.training:
  539. layer_outputs = self._gradient_checkpointing_func(
  540. encoder_layer.__call__,
  541. hidden_states,
  542. attention_mask,
  543. causal_attention_mask,
  544. output_attentions,
  545. )
  546. else:
  547. layer_outputs = encoder_layer(
  548. hidden_states,
  549. attention_mask,
  550. causal_attention_mask,
  551. output_attentions=output_attentions,
  552. )
  553. hidden_states = layer_outputs[0]
  554. if output_attentions:
  555. all_attentions = all_attentions + (layer_outputs[1],)
  556. if output_hidden_states:
  557. encoder_states = encoder_states + (hidden_states,)
  558. if not return_dict:
  559. return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
  560. return BaseModelOutput(
  561. last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
  562. )
  563. class CLIPSegTextTransformer(nn.Module):
  564. def __init__(self, config: CLIPSegTextConfig):
  565. super().__init__()
  566. self.config = config
  567. embed_dim = config.hidden_size
  568. self.embeddings = CLIPSegTextEmbeddings(config)
  569. self.encoder = CLIPSegEncoder(config)
  570. self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  571. # For `pooled_output` computation
  572. self.eos_token_id = config.eos_token_id
  573. @add_start_docstrings_to_model_forward(CLIPSEG_TEXT_INPUTS_DOCSTRING)
  574. @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPSegTextConfig)
  575. # Adapted from transformers.models.clip.modeling_clip.CLIPTextTransformer.forward with clip->clipseg, CLIP->CLIPSeg
  576. def forward(
  577. self,
  578. input_ids: Optional[torch.Tensor] = None,
  579. attention_mask: Optional[torch.Tensor] = None,
  580. position_ids: Optional[torch.Tensor] = None,
  581. output_attentions: Optional[bool] = None,
  582. output_hidden_states: Optional[bool] = None,
  583. return_dict: Optional[bool] = None,
  584. ) -> Union[Tuple, BaseModelOutputWithPooling]:
  585. r"""
  586. Returns:
  587. """
  588. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  589. output_hidden_states = (
  590. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  591. )
  592. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  593. if input_ids is None:
  594. raise ValueError("You have to specify input_ids")
  595. input_shape = input_ids.size()
  596. input_ids = input_ids.view(-1, input_shape[-1])
  597. hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
  598. # CLIPSeg's text model uses causal mask, prepare it here.
  599. # https://github.com/openai/CLIPSeg/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clipseg/model.py#L324
  600. causal_attention_mask = _create_4d_causal_attention_mask(
  601. input_shape, hidden_states.dtype, device=hidden_states.device
  602. )
  603. # expand attention_mask
  604. if attention_mask is not None:
  605. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  606. attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
  607. encoder_outputs = self.encoder(
  608. inputs_embeds=hidden_states,
  609. attention_mask=attention_mask,
  610. causal_attention_mask=causal_attention_mask,
  611. output_attentions=output_attentions,
  612. output_hidden_states=output_hidden_states,
  613. return_dict=return_dict,
  614. )
  615. last_hidden_state = encoder_outputs[0]
  616. last_hidden_state = self.final_layer_norm(last_hidden_state)
  617. if self.eos_token_id == 2:
  618. # The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
  619. # A CLIPSeg model with such `eos_token_id` in the config can't work correctly with extra new tokens added
  620. # ------------------------------------------------------------
  621. # text_embeds.shape = [batch_size, sequence_length, transformer.width]
  622. # take features from the eot embedding (eot_token is the highest number in each sequence)
  623. # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
  624. pooled_output = last_hidden_state[
  625. torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
  626. input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
  627. ]
  628. else:
  629. # The config gets updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
  630. pooled_output = last_hidden_state[
  631. torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
  632. # We need to get the first position of `eos_token_id` value (`pad_token_ids` might be equal to `eos_token_id`)
  633. # Note: we assume each sequence (along batch dim.) contains an `eos_token_id` (e.g. prepared by the tokenizer)
  634. (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id)
  635. .int()
  636. .argmax(dim=-1),
  637. ]
  638. if not return_dict:
  639. return (last_hidden_state, pooled_output) + encoder_outputs[1:]
  640. return BaseModelOutputWithPooling(
  641. last_hidden_state=last_hidden_state,
  642. pooler_output=pooled_output,
  643. hidden_states=encoder_outputs.hidden_states,
  644. attentions=encoder_outputs.attentions,
  645. )
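# Illustrative sketch (not part of the original file): how the pooled text output above is selected.
# In the legacy `eos_token_id == 2` branch, EOS is assumed to be the highest id in each sequence, so
# `argmax` over the raw ids finds it; the general branch below takes the first position whose id equals
# the configured EOS id. The helper name `_example_eos_pooling` is hypothetical.
def _example_eos_pooling(last_hidden_state: torch.Tensor, input_ids: torch.Tensor, eos_token_id: int) -> torch.Tensor:
    # last_hidden_state: (batch, seq_len, dim); input_ids: (batch, seq_len)
    batch_indices = torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device)
    # argmax of a 0/1 mask returns the first position where the mask is 1
    eos_positions = (input_ids.to(torch.int) == eos_token_id).int().argmax(dim=-1)
    return last_hidden_state[batch_indices, eos_positions]  # (batch, dim)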
  646. class CLIPSegTextModel(CLIPSegPreTrainedModel):
  647. config_class = CLIPSegTextConfig
  648. _no_split_modules = ["CLIPSegTextEmbeddings", "CLIPSegEncoderLayer"]
  649. def __init__(self, config: CLIPSegTextConfig):
  650. super().__init__(config)
  651. self.text_model = CLIPSegTextTransformer(config)
  652. # Initialize weights and apply final processing
  653. self.post_init()
  654. def get_input_embeddings(self) -> nn.Module:
  655. return self.text_model.embeddings.token_embedding
  656. def set_input_embeddings(self, value):
  657. self.text_model.embeddings.token_embedding = value
  658. @add_start_docstrings_to_model_forward(CLIPSEG_TEXT_INPUTS_DOCSTRING)
  659. @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPSegTextConfig)
  660. def forward(
  661. self,
  662. input_ids: Optional[torch.Tensor] = None,
  663. attention_mask: Optional[torch.Tensor] = None,
  664. position_ids: Optional[torch.Tensor] = None,
  665. output_attentions: Optional[bool] = None,
  666. output_hidden_states: Optional[bool] = None,
  667. return_dict: Optional[bool] = None,
  668. ) -> Union[Tuple, BaseModelOutputWithPooling]:
  669. r"""
  670. Returns:
  671. Examples:
  672. ```python
  673. >>> from transformers import AutoTokenizer, CLIPSegTextModel
  674. >>> tokenizer = AutoTokenizer.from_pretrained("CIDAS/clipseg-rd64-refined")
  675. >>> model = CLIPSegTextModel.from_pretrained("CIDAS/clipseg-rd64-refined")
  676. >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
  677. >>> outputs = model(**inputs)
  678. >>> last_hidden_state = outputs.last_hidden_state
  679. >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
  680. ```"""
  681. return self.text_model(
  682. input_ids=input_ids,
  683. attention_mask=attention_mask,
  684. position_ids=position_ids,
  685. output_attentions=output_attentions,
  686. output_hidden_states=output_hidden_states,
  687. return_dict=return_dict,
  688. )
  689. class CLIPSegVisionTransformer(nn.Module):
  690. # Copied from transformers.models.altclip.modeling_altclip.AltCLIPVisionTransformer.__init__ with AltCLIP->CLIPSeg
  691. def __init__(self, config: CLIPSegVisionConfig):
  692. super().__init__()
  693. self.config = config
  694. embed_dim = config.hidden_size
  695. self.embeddings = CLIPSegVisionEmbeddings(config)
  696. self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  697. self.encoder = CLIPSegEncoder(config)
  698. self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  699. @add_start_docstrings_to_model_forward(CLIPSEG_VISION_INPUTS_DOCSTRING)
  700. @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPSegVisionConfig)
  701. # Copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer.forward
  702. def forward(
  703. self,
  704. pixel_values: Optional[torch.FloatTensor] = None,
  705. output_attentions: Optional[bool] = None,
  706. output_hidden_states: Optional[bool] = None,
  707. return_dict: Optional[bool] = None,
  708. interpolate_pos_encoding: Optional[bool] = False,
  709. ) -> Union[Tuple, BaseModelOutputWithPooling]:
  710. r"""
  711. Returns:
  712. """
  713. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  714. output_hidden_states = (
  715. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  716. )
  717. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  718. if pixel_values is None:
  719. raise ValueError("You have to specify pixel_values")
  720. hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  721. hidden_states = self.pre_layrnorm(hidden_states)
  722. encoder_outputs = self.encoder(
  723. inputs_embeds=hidden_states,
  724. output_attentions=output_attentions,
  725. output_hidden_states=output_hidden_states,
  726. return_dict=return_dict,
  727. )
  728. last_hidden_state = encoder_outputs[0]
  729. pooled_output = last_hidden_state[:, 0, :]
  730. pooled_output = self.post_layernorm(pooled_output)
  731. if not return_dict:
  732. return (last_hidden_state, pooled_output) + encoder_outputs[1:]
  733. return BaseModelOutputWithPooling(
  734. last_hidden_state=last_hidden_state,
  735. pooler_output=pooled_output,
  736. hidden_states=encoder_outputs.hidden_states,
  737. attentions=encoder_outputs.attentions,
  738. )
  739. class CLIPSegVisionModel(CLIPSegPreTrainedModel):
  740. config_class = CLIPSegVisionConfig
  741. main_input_name = "pixel_values"
  742. def __init__(self, config: CLIPSegVisionConfig):
  743. super().__init__(config)
  744. self.vision_model = CLIPSegVisionTransformer(config)
  745. # Initialize weights and apply final processing
  746. self.post_init()
  747. def get_input_embeddings(self) -> nn.Module:
  748. return self.vision_model.embeddings.patch_embedding
  749. @add_start_docstrings_to_model_forward(CLIPSEG_VISION_INPUTS_DOCSTRING)
  750. @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPSegVisionConfig)
  751. def forward(
  752. self,
  753. pixel_values: Optional[torch.FloatTensor] = None,
  754. output_attentions: Optional[bool] = None,
  755. output_hidden_states: Optional[bool] = None,
  756. interpolate_pos_encoding: Optional[bool] = False,
  757. return_dict: Optional[bool] = None,
  758. ) -> Union[Tuple, BaseModelOutputWithPooling]:
  759. r"""
  760. Returns:
  761. Examples:
  762. ```python
  763. >>> from PIL import Image
  764. >>> import requests
  765. >>> from transformers import AutoProcessor, CLIPSegVisionModel
  766. >>> processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
  767. >>> model = CLIPSegVisionModel.from_pretrained("CIDAS/clipseg-rd64-refined")
  768. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  769. >>> image = Image.open(requests.get(url, stream=True).raw)
  770. >>> inputs = processor(images=image, return_tensors="pt")
  771. >>> outputs = model(**inputs)
  772. >>> last_hidden_state = outputs.last_hidden_state
  773. >>> pooled_output = outputs.pooler_output # pooled CLS states
  774. ```"""
  775. return self.vision_model(
  776. pixel_values=pixel_values,
  777. output_attentions=output_attentions,
  778. output_hidden_states=output_hidden_states,
  779. interpolate_pos_encoding=interpolate_pos_encoding,
  780. return_dict=return_dict,
  781. )
  782. @add_start_docstrings(CLIPSEG_START_DOCSTRING)
  783. class CLIPSegModel(CLIPSegPreTrainedModel):
  784. config_class = CLIPSegConfig
  785. def __init__(self, config: CLIPSegConfig):
  786. super().__init__(config)
  787. if not isinstance(config.text_config, CLIPSegTextConfig):
  788. raise TypeError(
  789. "config.text_config is expected to be of type CLIPSegTextConfig but is of type"
  790. f" {type(config.text_config)}."
  791. )
  792. if not isinstance(config.vision_config, CLIPSegVisionConfig):
  793. raise TypeError(
  794. "config.vision_config is expected to be of type CLIPSegVisionConfig but is of type"
  795. f" {type(config.vision_config)}."
  796. )
  797. text_config = config.text_config
  798. vision_config = config.vision_config
  799. self.projection_dim = config.projection_dim
  800. self.text_embed_dim = text_config.hidden_size
  801. self.vision_embed_dim = vision_config.hidden_size
  802. self.text_model = CLIPSegTextTransformer(text_config)
  803. self.vision_model = CLIPSegVisionTransformer(vision_config)
  804. self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
  805. self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
  806. self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
  807. # Initialize weights and apply final processing
  808. self.post_init()
  809. @add_start_docstrings_to_model_forward(CLIPSEG_TEXT_INPUTS_DOCSTRING)
  810. def get_text_features(
  811. self,
  812. input_ids: Optional[torch.Tensor] = None,
  813. attention_mask: Optional[torch.Tensor] = None,
  814. position_ids: Optional[torch.Tensor] = None,
  815. output_attentions: Optional[bool] = None,
  816. output_hidden_states: Optional[bool] = None,
  817. return_dict: Optional[bool] = None,
  818. ) -> torch.FloatTensor:
  819. r"""
  820. Returns:
  821. text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
  822. applying the projection layer to the pooled output of [`CLIPSegTextModel`].
  823. Examples:
  824. ```python
  825. >>> from transformers import AutoTokenizer, CLIPSegModel
  826. >>> tokenizer = AutoTokenizer.from_pretrained("CIDAS/clipseg-rd64-refined")
  827. >>> model = CLIPSegModel.from_pretrained("CIDAS/clipseg-rd64-refined")
  828. >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
  829. >>> text_features = model.get_text_features(**inputs)
  830. ```"""
  831. # Use CLIPSEG model's config for some fields (if specified) instead of those of vision & text components.
  832. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  833. output_hidden_states = (
  834. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  835. )
  836. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  837. text_outputs = self.text_model(
  838. input_ids=input_ids,
  839. attention_mask=attention_mask,
  840. position_ids=position_ids,
  841. output_attentions=output_attentions,
  842. output_hidden_states=output_hidden_states,
  843. return_dict=return_dict,
  844. )
  845. pooled_output = text_outputs[1]
  846. text_features = self.text_projection(pooled_output)
  847. return text_features
  848. @add_start_docstrings_to_model_forward(CLIPSEG_VISION_INPUTS_DOCSTRING)
  849. def get_image_features(
  850. self,
  851. pixel_values: Optional[torch.FloatTensor] = None,
  852. output_attentions: Optional[bool] = None,
  853. output_hidden_states: Optional[bool] = None,
  854. interpolate_pos_encoding: bool = False,
  855. return_dict: Optional[bool] = None,
  856. ) -> torch.FloatTensor:
  857. r"""
  858. Returns:
  859. image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
  860. applying the projection layer to the pooled output of [`CLIPSegVisionModel`].
  861. Examples:
  862. ```python
  863. >>> from PIL import Image
  864. >>> import requests
  865. >>> from transformers import AutoProcessor, CLIPSegModel
  866. >>> processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
  867. >>> model = CLIPSegModel.from_pretrained("CIDAS/clipseg-rd64-refined")
  868. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  869. >>> image = Image.open(requests.get(url, stream=True).raw)
  870. >>> inputs = processor(images=image, return_tensors="pt")
  871. >>> image_features = model.get_image_features(**inputs)
  872. ```"""
  873. # Use CLIPSEG model's config for some fields (if specified) instead of those of vision & text components.
  874. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  875. output_hidden_states = (
  876. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  877. )
  878. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  879. vision_outputs = self.vision_model(
  880. pixel_values=pixel_values,
  881. output_attentions=output_attentions,
  882. output_hidden_states=output_hidden_states,
  883. interpolate_pos_encoding=interpolate_pos_encoding,
  884. return_dict=return_dict,
  885. )
  886. pooled_output = vision_outputs[1] # pooled_output
  887. image_features = self.visual_projection(pooled_output)
  888. return image_features
  889. @add_start_docstrings_to_model_forward(CLIPSEG_INPUTS_DOCSTRING)
  890. @replace_return_docstrings(output_type=CLIPSegOutput, config_class=CLIPSegConfig)
  891. def forward(
  892. self,
  893. input_ids: Optional[torch.LongTensor] = None,
  894. pixel_values: Optional[torch.FloatTensor] = None,
  895. attention_mask: Optional[torch.Tensor] = None,
  896. position_ids: Optional[torch.LongTensor] = None,
  897. return_loss: Optional[bool] = None,
  898. output_attentions: Optional[bool] = None,
  899. output_hidden_states: Optional[bool] = None,
  900. interpolate_pos_encoding: bool = False,
  901. return_dict: Optional[bool] = None,
  902. ) -> Union[Tuple, CLIPSegOutput]:
  903. r"""
  904. Returns:
  905. Examples:
  906. ```python
  907. >>> from PIL import Image
  908. >>> import requests
  909. >>> from transformers import AutoProcessor, CLIPSegModel
  910. >>> processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
  911. >>> model = CLIPSegModel.from_pretrained("CIDAS/clipseg-rd64-refined")
  912. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  913. >>> image = Image.open(requests.get(url, stream=True).raw)
  914. >>> inputs = processor(
  915. ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
  916. ... )
  917. >>> outputs = model(**inputs)
  918. >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
  919. >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
  920. ```"""
  921. # Use CLIPSEG model's config for some fields (if specified) instead of those of vision & text components.
  922. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  923. output_hidden_states = (
  924. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  925. )
  926. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  927. vision_outputs = self.vision_model(
  928. pixel_values=pixel_values,
  929. output_attentions=output_attentions,
  930. output_hidden_states=output_hidden_states,
  931. interpolate_pos_encoding=interpolate_pos_encoding,
  932. return_dict=return_dict,
  933. )
  934. text_outputs = self.text_model(
  935. input_ids=input_ids,
  936. attention_mask=attention_mask,
  937. position_ids=position_ids,
  938. output_attentions=output_attentions,
  939. output_hidden_states=output_hidden_states,
  940. return_dict=return_dict,
  941. )
  942. image_embeds = vision_outputs[1]
  943. image_embeds = self.visual_projection(image_embeds)
  944. text_embeds = text_outputs[1]
  945. text_embeds = self.text_projection(text_embeds)
  946. # normalized features
  947. image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
  948. text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
  949. # cosine similarity as logits
  950. logit_scale = self.logit_scale.exp()
  951. logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
  952. logits_per_image = logits_per_text.t()
  953. loss = None
  954. if return_loss:
  955. loss = clipseg_loss(logits_per_text)
  956. if not return_dict:
  957. output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
  958. return ((loss,) + output) if loss is not None else output
  959. return CLIPSegOutput(
  960. loss=loss,
  961. logits_per_image=logits_per_image,
  962. logits_per_text=logits_per_text,
  963. text_embeds=text_embeds,
  964. image_embeds=image_embeds,
  965. text_model_output=text_outputs,
  966. vision_model_output=vision_outputs,
  967. )
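# Illustrative sketch (not part of the original file): the image-text logits computed in the forward
# above are a temperature-scaled cosine similarity between L2-normalized projected embeddings. The
# helper name `_example_similarity_logits` is hypothetical.
def _example_similarity_logits(text_embeds: torch.Tensor, image_embeds: torch.Tensor, logit_scale: torch.Tensor) -> torch.Tensor:
    text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
    image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
    # (text_batch, image_batch); transposing this gives logits_per_image
    return torch.matmul(text_embeds, image_embeds.t()) * logit_scale.exp()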
  968. class CLIPSegDecoderLayer(nn.Module):
  969. """
  970. CLIPSeg decoder layer, which is identical to `CLIPSegEncoderLayer`, except that normalization is applied after
  971. self-attention/MLP, rather than before.
  972. """
  973. # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer.__init__ with AltCLIP->CLIPSeg
  974. def __init__(self, config: CLIPSegConfig):
  975. super().__init__()
  976. self.embed_dim = config.hidden_size
  977. self.self_attn = CLIPSegAttention(config)
  978. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  979. self.mlp = CLIPSegMLP(config)
  980. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  981. def forward(
  982. self,
  983. hidden_states: torch.Tensor,
  984. attention_mask: torch.Tensor,
  985. causal_attention_mask: torch.Tensor,
  986. output_attentions: Optional[bool] = False,
  987. ) -> Tuple[torch.FloatTensor]:
  988. """
  989. Args:
  990. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  991. attention_mask (`torch.FloatTensor`): attention mask of size
  992. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  993. `(config.encoder_attention_heads,)`.
  994. output_attentions (`bool`, *optional*):
  995. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  996. returned tensors for more detail.
  997. """
        residual = hidden_states

        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )

        hidden_states = residual + hidden_states
        hidden_states = self.layer_norm1(hidden_states)

        residual = hidden_states
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        hidden_states = self.layer_norm2(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs
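

# For reference, the two residual-block orderings side by side. This is a hypothetical helper added for
# illustration only: `attn`, `mlp`, `norm1` and `norm2` are passed in as plain callables and are not
# attributes of any module in this file.
def _example_postnorm_vs_prenorm(x, attn, mlp, norm1, norm2):
    # Pre-norm (as in `CLIPSegEncoderLayer`): normalize *before* each sub-block.
    pre = x + attn(norm1(x))
    pre = pre + mlp(norm2(pre))
    # Post-norm (as in `CLIPSegDecoderLayer` above): normalize *after* each residual sum.
    post = norm1(x + attn(x))
    post = norm2(post + mlp(post))
    return pre, post
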

class CLIPSegDecoder(CLIPSegPreTrainedModel):
    def __init__(self, config: CLIPSegConfig):
        super().__init__(config)

        self.conditional_layer = config.conditional_layer

        # FiLM projections: map the conditional embedding to a per-channel scale and shift
        self.film_mul = nn.Linear(config.projection_dim, config.reduce_dim)
        self.film_add = nn.Linear(config.projection_dim, config.reduce_dim)

        # upsampling head: maps the decoded patch grid back to pixel resolution (a single channel of logits)
        if config.use_complex_transposed_convolution:
            transposed_kernels = (config.vision_config.patch_size // 4, config.vision_config.patch_size // 4)

            self.transposed_convolution = nn.Sequential(
                nn.Conv2d(config.reduce_dim, config.reduce_dim, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.ConvTranspose2d(
                    config.reduce_dim,
                    config.reduce_dim // 2,
                    kernel_size=transposed_kernels[0],
                    stride=transposed_kernels[0],
                ),
                nn.ReLU(),
                nn.ConvTranspose2d(
                    config.reduce_dim // 2, 1, kernel_size=transposed_kernels[1], stride=transposed_kernels[1]
                ),
            )
        else:
            self.transposed_convolution = nn.ConvTranspose2d(
                config.reduce_dim, 1, config.vision_config.patch_size, stride=config.vision_config.patch_size
            )

        # one linear "reduce" per extracted encoder layer, projecting hidden_size -> reduce_dim
        depth = len(config.extract_layers)
        self.reduces = nn.ModuleList(
            [nn.Linear(config.vision_config.hidden_size, config.reduce_dim) for _ in range(depth)]
        )

        decoder_config = copy.deepcopy(config.vision_config)
        decoder_config.hidden_size = config.reduce_dim
        decoder_config.num_attention_heads = config.decoder_num_attention_heads
        decoder_config.intermediate_size = config.decoder_intermediate_size
        decoder_config.hidden_act = "relu"
        self.layers = nn.ModuleList([CLIPSegDecoderLayer(decoder_config) for _ in range(len(config.extract_layers))])
    def forward(
        self,
        hidden_states: Tuple[torch.Tensor],
        conditional_embeddings: torch.Tensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = True,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # process the extracted encoder activations from the deepest layer to the shallowest
        activations = hidden_states[::-1]

        output = None
        for i, (activation, layer, reduce) in enumerate(zip(activations, self.layers, self.reduces)):
            if output is not None:
                output = reduce(activation) + output
            else:
                output = reduce(activation)

            if i == self.conditional_layer:
                # FiLM conditioning: scale and shift the tokens with projections of the conditional embedding
                output = self.film_mul(conditional_embeddings) * output.permute(1, 0, 2) + self.film_add(
                    conditional_embeddings
                )
                output = output.permute(1, 0, 2)

            layer_outputs = layer(
                output, attention_mask=None, causal_attention_mask=None, output_attentions=output_attentions
            )

            output = layer_outputs[0]

            if output_hidden_states:
                all_hidden_states += (output,)

            if output_attentions:
                all_attentions += (layer_outputs[1],)

        output = output[:, 1:, :].permute(0, 2, 1)  # remove cls token and reshape to [batch_size, reduce_dim, seq_len]

        size = int(math.sqrt(output.shape[2]))

        batch_size = conditional_embeddings.shape[0]
        output = output.view(batch_size, output.shape[1], size, size)

        logits = self.transposed_convolution(output).squeeze(1)

        if not return_dict:
            return tuple(v for v in [logits, all_hidden_states, all_attentions] if v is not None)

        return CLIPSegDecoderOutput(
            logits=logits,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
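

# The decoder conditions the reduced tokens with FiLM (feature-wise linear modulation): the conditional
# embedding is projected by `film_mul`/`film_add` to a per-channel scale and shift that are broadcast over
# the token dimension. The helper below is a hypothetical, illustration-only sketch of that step; the
# `unsqueeze`-based broadcasting is numerically equivalent to the permute trick used in the forward above.
def _example_film_conditioning(
    tokens: torch.Tensor, cond: torch.Tensor, film_mul: nn.Linear, film_add: nn.Linear
) -> torch.Tensor:
    # tokens: [batch_size, seq_len, reduce_dim], cond: [batch_size, projection_dim]
    scale = film_mul(cond).unsqueeze(1)  # [batch_size, 1, reduce_dim]
    shift = film_add(cond).unsqueeze(1)  # [batch_size, 1, reduce_dim]
    return tokens * scale + shift


# After the decoder layers, the patch tokens (minus the CLS token) are reshaped into a square grid of
# sqrt(seq_len - 1) tokens per side and upsampled by `patch_size` back to pixel resolution. For the
# 352x352 inputs used in the examples in this file with a patch size of 16, that is a 22x22 grid
# upsampled to 352x352, either by a single stride-16 transposed convolution or by the "complex" head
# of two stride-4 transposed convolutions.
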

@add_start_docstrings(
    """
    CLIPSeg model with a Transformer-based decoder on top for zero-shot and one-shot image segmentation.
    """,
    CLIPSEG_START_DOCSTRING,
)
class CLIPSegForImageSegmentation(CLIPSegPreTrainedModel):
    config_class = CLIPSegConfig

    def __init__(self, config: CLIPSegConfig):
        super().__init__(config)

        self.config = config

        self.clip = CLIPSegModel(config)
        self.extract_layers = config.extract_layers

        self.decoder = CLIPSegDecoder(config)

        # Initialize weights and apply final processing
        self.post_init()
    def get_conditional_embeddings(
        self,
        batch_size: Optional[int] = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        conditional_pixel_values: Optional[torch.Tensor] = None,
    ):
        if input_ids is not None:
            # compute conditional embeddings from texts
            if len(input_ids) != batch_size:
                raise ValueError("Make sure to pass as many prompt texts as there are query images")
            with torch.no_grad():
                conditional_embeddings = self.clip.get_text_features(
                    input_ids, attention_mask=attention_mask, position_ids=position_ids
                )
        elif conditional_pixel_values is not None:
            # compute conditional embeddings from images
            if len(conditional_pixel_values) != batch_size:
                raise ValueError("Make sure to pass as many prompt images as there are query images")
            with torch.no_grad():
                conditional_embeddings = self.clip.get_image_features(conditional_pixel_values)
        else:
            raise ValueError(
                "Invalid conditional, should be either provided as `input_ids` or `conditional_pixel_values`"
            )

        return conditional_embeddings
    @add_start_docstrings_to_model_forward(CLIPSEG_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CLIPSegImageSegmentationOutput, config_class=CLIPSegTextConfig)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        conditional_pixel_values: Optional[torch.FloatTensor] = None,
        conditional_embeddings: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CLIPSegOutput]:
        r"""
        labels (`torch.FloatTensor` of shape `(batch_size, height, width)`, *optional*):
            Ground-truth binary segmentation maps used to compute the loss. Values should be in `[0, 1]` and the
            spatial size should match that of the predicted `logits`; the loss is a `BCEWithLogitsLoss` between
            the predicted logits and these maps.

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoProcessor, CLIPSegForImageSegmentation
        >>> from PIL import Image
        >>> import requests

        >>> processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
        >>> model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> texts = ["a cat", "a remote", "a blanket"]
        >>> inputs = processor(text=texts, images=[image] * len(texts), padding=True, return_tensors="pt")

        >>> outputs = model(**inputs)

        >>> logits = outputs.logits
        >>> print(logits.shape)
        torch.Size([3, 352, 352])
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # step 1: forward the query images through the frozen CLIP vision encoder
        with torch.no_grad():
            vision_outputs = self.clip.vision_model(
                pixel_values=pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=True,  # we need the intermediate hidden states
                interpolate_pos_encoding=interpolate_pos_encoding,
                return_dict=return_dict,
            )
            pooled_output = self.clip.visual_projection(vision_outputs[1])

            hidden_states = vision_outputs.hidden_states if return_dict else vision_outputs[2]
            # we add +1 here as the hidden states also include the initial embeddings
            activations = [hidden_states[i + 1] for i in self.extract_layers]

            # update vision_outputs
            if return_dict:
                vision_outputs = BaseModelOutputWithPooling(
                    last_hidden_state=vision_outputs.last_hidden_state,
                    pooler_output=vision_outputs.pooler_output,
                    hidden_states=vision_outputs.hidden_states if output_hidden_states else None,
                    attentions=vision_outputs.attentions,
                )
            else:
                vision_outputs = (
                    vision_outputs[:2] + vision_outputs[3:] if not output_hidden_states else vision_outputs
                )

        # step 2: compute conditional embeddings, either from text, images or an own provided embedding
        if conditional_embeddings is None:
            conditional_embeddings = self.get_conditional_embeddings(
                batch_size=pixel_values.shape[0],
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                conditional_pixel_values=conditional_pixel_values,
            )
        else:
            if conditional_embeddings.shape[0] != pixel_values.shape[0]:
                raise ValueError(
                    "Make sure to pass as many conditional embeddings as there are query images in the batch"
                )
            if conditional_embeddings.shape[1] != self.config.projection_dim:
                raise ValueError(
                    "Make sure that the feature dimension of the conditional embeddings matches"
                    " `config.projection_dim`."
                )

        # step 3: forward both the pooled output and the activations through the lightweight decoder to predict masks
        decoder_outputs = self.decoder(
            activations,
            conditional_embeddings,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = decoder_outputs.logits if return_dict else decoder_outputs[0]

        loss = None
        if labels is not None:
            # move labels to the correct device to enable PP
            labels = labels.to(logits.device)
            loss_fn = nn.BCEWithLogitsLoss()
            loss = loss_fn(logits, labels)

        if not return_dict:
            output = (logits, conditional_embeddings, pooled_output, vision_outputs, decoder_outputs)
            return ((loss,) + output) if loss is not None else output

        return CLIPSegImageSegmentationOutput(
            loss=loss,
            logits=logits,
            conditional_embeddings=conditional_embeddings,
            pooled_output=pooled_output,
            vision_model_output=vision_outputs,
            decoder_output=decoder_outputs,
        )
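

# A minimal end-to-end sketch of one-shot conditioning and of the training loss, added for illustration
# only. Everything below is hypothetical and not part of the public API: random tensors stand in for a
# real processed prompt image and for real ground-truth masks, and the checkpoint name is the one used
# in the docstring examples above.
def _example_one_shot_and_training():
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
    model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

    # Query images (random stand-ins for a processed batch of two 352x352 images).
    pixel_values = torch.randn(2, 3, 352, 352)

    # One-shot: condition on prompt images instead of text. The resulting embeddings live in the shared
    # projection space, [batch_size, config.projection_dim], and could be passed to `forward` via
    # `conditional_embeddings=` in place of `input_ids`.
    conditional_embeddings = model.get_conditional_embeddings(
        batch_size=2, conditional_pixel_values=torch.randn(2, 3, 352, 352)
    )

    # Zero-shot training step: condition on text prompts and supervise with binary ground-truth masks of
    # the same spatial size as the predicted logits (BCE-with-logits loss).
    text_inputs = processor(text=["a cat", "a dog"], padding=True, return_tensors="pt")
    labels = torch.randint(0, 2, (2, 352, 352)).float()
    outputs = model(
        input_ids=text_inputs.input_ids,
        attention_mask=text_inputs.attention_mask,
        pixel_values=pixel_values,
        labels=labels,
    )
    # Only the lightweight decoder receives gradients; the CLIP towers are run under `torch.no_grad()`.
    outputs.loss.backward()

    return outputs.logits.shape, conditional_embeddings.shape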