modeling_tf_groupvit.py 88 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138
  1. # coding=utf-8
  2. # Copyright 2022 NVIDIA and The HuggingFace Team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """TF 2.0 GroupViT model."""
  16. from __future__ import annotations
  17. import collections.abc
  18. import math
  19. from dataclasses import dataclass
  20. from typing import Any, Optional, Tuple, Union
  21. import numpy as np
  22. import tensorflow as tf
  23. from ...activations_tf import get_tf_activation
  24. from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling
  25. from ...modeling_tf_utils import (
  26. TFModelInputType,
  27. TFPreTrainedModel,
  28. get_initializer,
  29. keras,
  30. keras_serializable,
  31. unpack_inputs,
  32. )
  33. from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
  34. from ...utils import (
  35. ModelOutput,
  36. add_start_docstrings,
  37. add_start_docstrings_to_model_forward,
  38. is_tensorflow_probability_available,
  39. logging,
  40. replace_return_docstrings,
  41. )
  42. from .configuration_groupvit import GroupViTConfig, GroupViTTextConfig, GroupViTVisionConfig
  43. logger = logging.get_logger(__name__)
  44. # soft dependency
# Soft dependency on TensorFlow Probability (needed for Gumbel sampling below).
# Both branches attempt the import; they differ only in how a failure is reported:
# - availability check passed but import still fails -> likely a TF/TFP version
#   mismatch, so log an actionable error;
# - availability check failed -> opportunistically try anyway and stay silent on failure.
if is_tensorflow_probability_available():
    try:
        import tensorflow_probability as tfp

        # On the first call, check whether a compatible version of TensorFlow is installed
        # TensorFlow Probability depends on a recent stable release of TensorFlow
        _ = tfp.distributions.Normal(loc=0.0, scale=1.0)
    except ImportError:
        logger.error(
            "GroupViT models are not usable since `tensorflow_probability` can't be loaded. "
            "It seems you have `tensorflow_probability` installed with the wrong tensorflow version."
            "Please try to reinstall it following the instructions here: https://github.com/tensorflow/probability."
        )
else:
    try:
        import tensorflow_probability as tfp

        # On the first call, check whether a compatible version of TensorFlow is installed
        # TensorFlow Probability depends on a recent stable release of TensorFlow
        _ = tfp.distributions.Normal(loc=0.0, scale=1.0)
    except ImportError:
        pass
  65. _CHECKPOINT_FOR_DOC = "nvidia/groupvit-gcc-yfcc"
  66. LARGE_NEGATIVE = -1e8
  67. # Copied from transformers.models.bart.modeling_tf_bart._expand_mask
  68. def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
  69. """
  70. Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
  71. """
  72. src_len = shape_list(mask)[1]
  73. tgt_len = tgt_len if tgt_len is not None else src_len
  74. one_cst = tf.constant(1.0)
  75. mask = tf.cast(mask, dtype=one_cst.dtype)
  76. expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))
  77. return (one_cst - expanded_mask) * LARGE_NEGATIVE
  78. # contrastive loss function, adapted from
  79. # https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
  80. def contrastive_loss(logits: tf.Tensor) -> tf.Tensor:
  81. return tf.math.reduce_mean(
  82. keras.metrics.sparse_categorical_crossentropy(
  83. y_true=tf.range(shape_list(logits)[0]), y_pred=logits, from_logits=True
  84. )
  85. )
  86. # Copied from transformers.models.clip.modeling_tf_clip.clip_loss with clip->groupvit
  87. def groupvit_loss(similarity: tf.Tensor) -> tf.Tensor:
  88. caption_loss = contrastive_loss(similarity)
  89. image_loss = contrastive_loss(tf.transpose(similarity))
  90. return (caption_loss + image_loss) / 2.0
  91. def hard_softmax(logits: tf.Tensor, dim: int) -> tf.Tensor:
  92. y_soft = stable_softmax(logits, dim)
  93. # Straight through.
  94. index = tf.argmax(y_soft, dim)
  95. y_hard = tf.one_hot(
  96. index,
  97. depth=shape_list(logits)[dim],
  98. # TensorFlow expects axis to be -1 or between [0, 3). But received: -2
  99. # This is why the following code snippet is used.
  100. axis=range(len(shape_list(logits)))[dim],
  101. dtype=y_soft.dtype,
  102. )
  103. ret = y_hard - tf.stop_gradient(y_soft) + y_soft
  104. return ret
  105. def gumbel_softmax(logits: tf.Tensor, tau: float = 1, hard: bool = False, dim: int = -1) -> tf.Tensor:
  106. gumbel_dist = tfp.distributions.Gumbel(0.0, 1.0)
  107. gumbels = gumbel_dist.sample(tf.shape(logits), dtype=logits.dtype)
  108. gumbels = (logits + gumbels) / tau # ~Gumbel(logits,tau)
  109. y_soft = stable_softmax(gumbels, dim)
  110. if hard:
  111. # Straight through.
  112. index = tf.argmax(y_soft, dim)
  113. y_hard = tf.one_hot(
  114. index,
  115. depth=shape_list(logits)[dim],
  116. # TensorFlow expects axis to be -1 or between [0, 3). But received: -2
  117. # This is why the following code snippet is used.
  118. axis=range(len(shape_list(logits)))[dim],
  119. dtype=y_soft.dtype,
  120. )
  121. ret = y_hard - tf.stop_gradient(y_soft) + y_soft
  122. else:
  123. # Reparametrization trick.
  124. ret = y_soft
  125. return ret
  126. def resize_attention_map(attentions: tf.Tensor, height: int, width: int, align_corners: bool = False) -> tf.Tensor:
  127. """
  128. Args:
  129. attentions (`tf.Tensor`): attention map of shape [batch_size, groups, feat_height*feat_width]
  130. height (`int`): height of the output attention map
  131. width (`int`): width of the output attention map
  132. align_corners (`bool`, *optional*): the `align_corner` argument for `nn.functional.interpolate`.
  133. Returns:
  134. `tf.Tensor`: resized attention map of shape [batch_size, groups, height, width]
  135. """
  136. scale = (height * width // attentions.shape[2]) ** 0.5
  137. if height > width:
  138. feat_width = int(np.round(width / scale))
  139. feat_height = shape_list(attentions)[2] // feat_width
  140. else:
  141. feat_height = int(np.round(height / scale))
  142. feat_width = shape_list(attentions)[2] // feat_height
  143. batch_size = shape_list(attentions)[0]
  144. groups = shape_list(attentions)[1] # number of group token
  145. # [batch_size, groups, height x width, groups] -> [batch_size, groups, height, width]
  146. attentions = tf.reshape(attentions, (batch_size, groups, feat_height, feat_width))
  147. attentions = tf.transpose(attentions, perm=(0, 2, 3, 1))
  148. if align_corners:
  149. attentions = tf.compat.v1.image.resize(
  150. attentions,
  151. size=(height, width),
  152. method="bilinear",
  153. align_corners=align_corners,
  154. )
  155. else:
  156. attentions = tf.image.resize(attentions, size=(height, width), method="bilinear")
  157. attentions = tf.transpose(attentions, perm=(0, 3, 1, 2))
  158. return attentions
  159. def get_grouping_from_attentions(attentions: Tuple[tf.Tensor], hw_shape: Tuple[int]) -> tf.Tensor:
  160. """
  161. Args:
  162. attentions (`tuple(tf.Tensor)`: tuple of attention maps returned by `TFGroupViTVisionTransformer`
  163. hw_shape (`tuple(int)`): height and width of the output attention map
  164. Returns:
  165. `tf.Tensor`: the attention map of shape [batch_size, groups, height, width]
  166. """
  167. attn_maps = []
  168. prev_attn_masks = None
  169. for attn_masks in attentions:
  170. # [batch_size, num_groups, height x width] -> [batch_size, height x width, num_groups]
  171. attn_masks = tf.transpose(attn_masks, perm=(0, 2, 1))
  172. if prev_attn_masks is None:
  173. prev_attn_masks = attn_masks
  174. else:
  175. prev_attn_masks = tf.matmul(prev_attn_masks, attn_masks)
  176. # [batch_size, height x width, num_groups] -> [batch_size, num_groups, height x width] -> [batch_size, num_groups, height, width]
  177. cur_attn_map = resize_attention_map(tf.transpose(prev_attn_masks, perm=(0, 2, 1)), *hw_shape)
  178. attn_maps.append(cur_attn_map)
  179. # [batch_size, num_groups, height, width]
  180. final_grouping = attn_maps[-1]
  181. return tf.stop_gradient(final_grouping)
  182. @dataclass
  183. class TFGroupViTModelOutput(ModelOutput):
  184. """
  185. Args:
  186. loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
  187. Contrastive loss for image-text similarity.
  188. logits_per_image (`tf.Tensor` of shape `(image_batch_size, text_batch_size)`):
  189. The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
  190. similarity scores.
  191. logits_per_text (`tf.Tensor` of shape `(text_batch_size, image_batch_size)`):
  192. The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
  193. similarity scores.
  194. segmentation_logits (`tf.Tensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`):
  195. Classification scores for each pixel.
  196. <Tip warning={true}>
  197. The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is
  198. to avoid doing two interpolations and lose some quality when a user needs to resize the logits to the
  199. original image size as post-processing. You should always check your logits shape and resize as needed.
  200. </Tip>
  201. text_embeds (`tf.Tensor` of shape `(batch_size, output_dim`):
  202. The text embeddings obtained by applying the projection layer to the pooled output of
  203. [`TFGroupViTTextModel`].
  204. image_embeds (`tf.Tensor` of shape `(batch_size, output_dim`):
  205. The image embeddings obtained by applying the projection layer to the pooled output of
  206. [`TFGroupViTVisionModel`].
  207. text_model_output (`TFBaseModelOutputWithPooling`):
  208. The output of the [`TFGroupViTTextModel`].
  209. vision_model_output (`TFBaseModelOutputWithPooling`):
  210. The output of the [`TFGroupViTVisionModel`].
  211. """
  212. loss: tf.Tensor | None = None
  213. logits_per_image: tf.Tensor = None
  214. logits_per_text: tf.Tensor = None
  215. segmentation_logits: tf.Tensor = None
  216. text_embeds: tf.Tensor = None
  217. image_embeds: tf.Tensor = None
  218. text_model_output: TFBaseModelOutputWithPooling = None
  219. vision_model_output: TFBaseModelOutputWithPooling = None
  220. def to_tuple(self) -> Tuple[Any]:
  221. return tuple(
  222. self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
  223. for k in self.keys()
  224. )
  225. class TFGroupViTCrossAttentionLayer(keras.layers.Layer):
  226. def __init__(self, config: GroupViTVisionConfig, **kwargs):
  227. super().__init__(**kwargs)
  228. self.attn = TFGroupViTAttention(config, name="attn")
  229. self.norm2 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="norm2")
  230. self.mlp = TFGroupViTMLP(config, name="mlp")
  231. self.norm_post = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="norm_post")
  232. self.config = config
  233. def call(self, query: tf.Tensor, key: tf.Tensor, training: bool = False) -> tf.Tensor:
  234. x = query
  235. x = x + self.attn(query, encoder_hidden_states=key)[0]
  236. x = x + self.mlp(self.norm2(x))
  237. x = self.norm_post(x)
  238. return x
  239. def build(self, input_shape=None):
  240. if self.built:
  241. return
  242. self.built = True
  243. if getattr(self, "attn", None) is not None:
  244. with tf.name_scope(self.attn.name):
  245. self.attn.build(None)
  246. if getattr(self, "norm2", None) is not None:
  247. with tf.name_scope(self.norm2.name):
  248. self.norm2.build([None, None, self.config.hidden_size])
  249. if getattr(self, "mlp", None) is not None:
  250. with tf.name_scope(self.mlp.name):
  251. self.mlp.build(None)
  252. if getattr(self, "norm_post", None) is not None:
  253. with tf.name_scope(self.norm_post.name):
  254. self.norm_post.build([None, None, self.config.hidden_size])
  255. class TFGroupViTAssignAttention(keras.layers.Layer):
  256. def __init__(self, config: GroupViTVisionConfig, **kwargs):
  257. super().__init__(**kwargs)
  258. self.scale = config.hidden_size**-0.5
  259. self.q_proj = keras.layers.Dense(config.hidden_size, name="q_proj")
  260. self.k_proj = keras.layers.Dense(config.hidden_size, name="k_proj")
  261. self.v_proj = keras.layers.Dense(config.hidden_size, name="v_proj")
  262. self.proj = keras.layers.Dense(config.hidden_size, name="proj")
  263. self.assign_eps = config.assign_eps
  264. self.config = config
  265. def get_attn(self, attn: tf.Tensor, gumbel: bool = True, hard: bool = True, training: bool = False) -> tf.Tensor:
  266. if gumbel and training:
  267. attn = gumbel_softmax(attn, dim=-2, hard=hard)
  268. else:
  269. if hard:
  270. attn = hard_softmax(attn, dim=-2)
  271. else:
  272. attn = stable_softmax(attn, axis=-2)
  273. return attn
  274. def call(self, query: tf.Tensor, key: tf.Tensor, training: bool = False):
  275. value = key
  276. # [batch_size, query_length, channels]
  277. query = self.q_proj(query)
  278. # [batch_size, key_length, channels]
  279. key = self.k_proj(key)
  280. # [batch_size, key_length, channels]
  281. value = self.v_proj(value)
  282. # [batch_size, query_length, key_length]
  283. raw_attn = tf.matmul(query, key, transpose_b=True) * self.scale
  284. attn = self.get_attn(raw_attn, training=training)
  285. soft_attn = self.get_attn(raw_attn, training=training, gumbel=False, hard=False)
  286. attn = attn / (tf.math.reduce_sum(attn, axis=-1, keepdims=True) + self.assign_eps)
  287. out = tf.matmul(attn, value)
  288. out = self.proj(out)
  289. return out, soft_attn
  290. def build(self, input_shape=None):
  291. if self.built:
  292. return
  293. self.built = True
  294. if getattr(self, "q_proj", None) is not None:
  295. with tf.name_scope(self.q_proj.name):
  296. self.q_proj.build([None, None, self.config.hidden_size])
  297. if getattr(self, "k_proj", None) is not None:
  298. with tf.name_scope(self.k_proj.name):
  299. self.k_proj.build([None, None, self.config.hidden_size])
  300. if getattr(self, "v_proj", None) is not None:
  301. with tf.name_scope(self.v_proj.name):
  302. self.v_proj.build([None, None, self.config.hidden_size])
  303. if getattr(self, "proj", None) is not None:
  304. with tf.name_scope(self.proj.name):
  305. self.proj.build([None, None, self.config.hidden_size])
  306. class TFGroupViTTokenAssign(keras.layers.Layer):
  307. def __init__(self, config: GroupViTVisionConfig, num_group_token: int, num_output_group: int, **kwargs):
  308. super().__init__(**kwargs)
  309. self.num_output_group = num_output_group
  310. # norm on group_tokens
  311. self.norm_tokens = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="norm_tokens")
  312. assign_mlp_ratio = (
  313. config.assign_mlp_ratio
  314. if isinstance(config.assign_mlp_ratio, collections.abc.Iterable)
  315. else (config.assign_mlp_ratio, config.assign_mlp_ratio)
  316. )
  317. tokens_dim, channels_dim = [int(x * config.hidden_size) for x in assign_mlp_ratio]
  318. self.mlp_inter = TFGroupViTMixerMLP(config, num_group_token, tokens_dim, num_output_group, name="mlp_inter")
  319. self.norm_post_tokens = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="norm_post_tokens")
  320. # norm on x
  321. self.norm_x = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="norm_x")
  322. self.pre_assign_attn = TFGroupViTCrossAttentionLayer(config, name="pre_assign_attn")
  323. self.assign = TFGroupViTAssignAttention(config, name="assign")
  324. self.norm_new_x = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="norm_new_x")
  325. self.mlp_channels = TFGroupViTMLP(
  326. config, config.hidden_size, channels_dim, config.hidden_size, name="mlp_channels"
  327. )
  328. self.config = config
  329. def project_group_token(self, group_tokens: tf.Tensor) -> tf.Tensor:
  330. """
  331. Args:
  332. group_tokens (tf.Tensor): group tokens, [batch_size, num_group_tokens, channels]
  333. Returns:
  334. projected_group_tokens (tf.Tensor): [batch_size, num_output_groups, channels]
  335. """
  336. # [B, num_output_groups, C] <- [B, num_group_tokens, C]
  337. projected_group_tokens = self.mlp_inter(group_tokens)
  338. projected_group_tokens = self.norm_post_tokens(projected_group_tokens)
  339. return projected_group_tokens
  340. def call(self, image_tokens: tf.Tensor, group_tokens: tf.Tensor, training: bool = False):
  341. """
  342. Args:
  343. image_tokens (`tf.Tensor`): image tokens, of shape [batch_size, input_length, channels]
  344. group_tokens (`tf.Tensor`): group tokens, [batch_size, num_group_tokens, channels]
  345. """
  346. group_tokens = self.norm_tokens(group_tokens)
  347. image_tokens = self.norm_x(image_tokens)
  348. # [batch_size, num_output_groups, channels]
  349. projected_group_tokens = self.project_group_token(group_tokens)
  350. projected_group_tokens = self.pre_assign_attn(projected_group_tokens, image_tokens)
  351. new_image_tokens, attention = self.assign(projected_group_tokens, image_tokens)
  352. new_image_tokens += projected_group_tokens
  353. new_image_tokens = new_image_tokens + self.mlp_channels(self.norm_new_x(new_image_tokens))
  354. return new_image_tokens, attention
  355. def build(self, input_shape=None):
  356. if self.built:
  357. return
  358. self.built = True
  359. if getattr(self, "norm_tokens", None) is not None:
  360. with tf.name_scope(self.norm_tokens.name):
  361. self.norm_tokens.build([None, None, self.config.hidden_size])
  362. if getattr(self, "mlp_inter", None) is not None:
  363. with tf.name_scope(self.mlp_inter.name):
  364. self.mlp_inter.build(None)
  365. if getattr(self, "norm_post_tokens", None) is not None:
  366. with tf.name_scope(self.norm_post_tokens.name):
  367. self.norm_post_tokens.build([None, None, self.config.hidden_size])
  368. if getattr(self, "norm_x", None) is not None:
  369. with tf.name_scope(self.norm_x.name):
  370. self.norm_x.build([None, None, self.config.hidden_size])
  371. if getattr(self, "pre_assign_attn", None) is not None:
  372. with tf.name_scope(self.pre_assign_attn.name):
  373. self.pre_assign_attn.build(None)
  374. if getattr(self, "assign", None) is not None:
  375. with tf.name_scope(self.assign.name):
  376. self.assign.build(None)
  377. if getattr(self, "norm_new_x", None) is not None:
  378. with tf.name_scope(self.norm_new_x.name):
  379. self.norm_new_x.build([None, None, self.config.hidden_size])
  380. if getattr(self, "mlp_channels", None) is not None:
  381. with tf.name_scope(self.mlp_channels.name):
  382. self.mlp_channels.build(None)
  383. # Adapted from transformers.models.vit.modeling_tf_vit.TFViTPatchEmbeddings with ViT->GroupViT
  384. class TFGroupViTPatchEmbeddings(keras.layers.Layer):
  385. """
  386. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  387. `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
  388. Transformer.
  389. """
  390. def __init__(self, config: GroupViTConfig, **kwargs):
  391. super().__init__(**kwargs)
  392. image_size, patch_size = config.image_size, config.patch_size
  393. num_channels = config.num_channels
  394. # hidden_size is a member as it will be required in the call method
  395. self.hidden_size = config.hidden_size
  396. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  397. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  398. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  399. self.image_size = image_size
  400. self.patch_size = patch_size
  401. self.num_patches = num_patches
  402. self.num_channels = num_channels
  403. self.config = config
  404. self.projection = keras.layers.Conv2D(
  405. filters=self.hidden_size,
  406. kernel_size=patch_size,
  407. strides=patch_size,
  408. padding="valid",
  409. data_format="channels_last",
  410. use_bias=True,
  411. kernel_initializer=get_initializer(self.config.initializer_range),
  412. bias_initializer="zeros",
  413. name="projection",
  414. )
  415. def call(
  416. self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False, training: bool = False
  417. ) -> tf.Tensor:
  418. batch_size, num_channels, height, width = shape_list(pixel_values)
  419. if tf.executing_eagerly() and num_channels != self.num_channels:
  420. raise ValueError(
  421. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  422. )
  423. if (
  424. not interpolate_pos_encoding
  425. and tf.executing_eagerly()
  426. and (height != self.image_size[0] or width != self.image_size[1])
  427. ):
  428. raise ValueError(
  429. f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
  430. )
  431. # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format.
  432. # So change the input format from `NCHW` to `NHWC`.
  433. # shape = (batch_size, in_height, in_width, in_channels=num_channels)
  434. pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
  435. projection = self.projection(pixel_values)
  436. # Change the 2D spatial dimensions to a single temporal dimension.
  437. # shape = (batch_size, num_patches, out_channels=embed_dim)
  438. num_patches = (width // self.patch_size[1]) * (height // self.patch_size[0])
  439. # In the TFGroupViTVisionEmbeddings the embeddings from this layer will be layer normalized
  440. # LayerNormalization layer needs to have static last dimension (otherwise the test_keras_save_load fails with symbolic tensors)
  441. # This is why we have used the hidden_size in the reshape method
  442. embeddings = tf.reshape(tensor=projection, shape=(batch_size, num_patches, self.hidden_size))
  443. return embeddings
  444. def build(self, input_shape=None):
  445. if self.built:
  446. return
  447. self.built = True
  448. if getattr(self, "projection", None) is not None:
  449. with tf.name_scope(self.projection.name):
  450. self.projection.build([None, None, None, self.num_channels])
  451. # Adapted from transformers.vit.modeling_tf_vit.TFViTEmbeddings
class TFGroupViTVisionEmbeddings(keras.layers.Layer):
    """
    Construct the position and patch embeddings.
    """

    def __init__(self, config: GroupViTVisionConfig, **kwargs):
        super().__init__(**kwargs)
        # Patch projection, dropout, and a LayerNorm applied right after patching.
        self.patch_embeddings = TFGroupViTPatchEmbeddings(config, name="patch_embeddings")
        self.dropout = keras.layers.Dropout(rate=config.dropout, name="dropout")
        self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
        self.config = config

    def build(self, input_shape=None):
        # NOTE(review): the position-embedding weight is created *before* the
        # `built` early-return below, so this line runs on every build() call —
        # confirm this ordering is intentional before restructuring.
        num_patches = self.patch_embeddings.num_patches
        self.position_embeddings = self.add_weight(
            shape=(1, num_patches, self.config.hidden_size),
            initializer="zeros",
            trainable=True,
            name="position_embeddings",
        )
        if self.built:
            return
        self.built = True
        if getattr(self, "patch_embeddings", None) is not None:
            with tf.name_scope(self.patch_embeddings.name):
                self.patch_embeddings.build(None)
        if getattr(self, "dropout", None) is not None:
            with tf.name_scope(self.dropout.name):
                self.dropout.build(None)
        if getattr(self, "layernorm", None) is not None:
            with tf.name_scope(self.layernorm.name):
                self.layernorm.build([None, None, self.config.hidden_size])

    def interpolate_pos_encoding(self, embeddings, height, width) -> tf.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
        resolution images.
        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
        """
        batch_size, num_patches, dim = shape_list(embeddings)
        num_positions = shape_list(self.position_embeddings)[1]
        # Fast path: the patch grid matches the pre-trained one, nothing to interpolate.
        if num_patches == num_positions and height == width:
            return self.position_embeddings
        patch_pos_embed = self.position_embeddings
        # Target grid size in patches; the reshape below assumes the pre-trained
        # grid is square (sqrt(num_positions) patches per side).
        h0 = height // self.config.patch_size
        w0 = width // self.config.patch_size
        patch_pos_embed = tf.image.resize(
            images=tf.reshape(
                patch_pos_embed, shape=(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
            ),
            size=(h0, w0),
            method="bicubic",
        )
        # Flatten the interpolated grid back to (1, h0*w0, dim).
        patch_pos_embed = tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim))
        return patch_pos_embed

    def call(
        self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False, training: bool = False
    ) -> tf.Tensor:
        _, _, height, width = shape_list(pixel_values)
        embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
        embeddings = self.layernorm(embeddings)
        # add positional encoding to each token
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings
# Copied from transformers.models.clip.modeling_tf_clip.TFCLIPTextEmbeddings with CLIP->GroupViT
class TFGroupViTTextEmbeddings(keras.layers.Layer):
    """Token embeddings plus learned absolute position embeddings for the text encoder."""

    def __init__(self, config: GroupViTTextConfig, **kwargs):
        super().__init__(**kwargs)

        self.embed_dim = config.hidden_size

        self.config = config

    def build(self, input_shape: tf.TensorShape = None):
        # Weights are created here (not in __init__) inside explicit name scopes
        # so the variable names are stable for checkpoint (de)serialization.
        with tf.name_scope("token_embedding"):
            self.weight = self.add_weight(
                shape=(self.config.vocab_size, self.embed_dim),
                initializer=get_initializer(self.config.initializer_factor * self.config.initializer_range),
                trainable=True,
                name="weight",
            )

        with tf.name_scope("position_embedding"):
            self.position_embedding = self.add_weight(
                shape=(self.config.max_position_embeddings, self.embed_dim),
                initializer=get_initializer(self.config.initializer_factor * self.config.initializer_range),
                trainable=True,
                name="embeddings",
            )

        super().build(input_shape)

    def call(
        self,
        input_ids: tf.Tensor = None,
        position_ids: tf.Tensor = None,
        inputs_embeds: tf.Tensor = None,
    ) -> tf.Tensor:
        """
        Applies embedding based on inputs tensor.

        Returns:
            final_embeddings (`tf.Tensor`): output embedding tensor.
        """
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            # Guard against out-of-range token ids before the gather.
            check_embeddings_within_bounds(input_ids, self.config.vocab_size)
            inputs_embeds = tf.gather(params=self.weight, indices=input_ids)

        input_shape = shape_list(inputs_embeds)[:-1]

        if position_ids is None:
            # Default to positions 0..seq_len-1, shared across the batch.
            position_ids = tf.expand_dims(tf.range(start=0, limit=input_shape[-1]), axis=0)

        position_embeds = tf.gather(params=self.position_embedding, indices=position_ids)
        # Broadcast the position embeddings across the batch dimension.
        position_embeds = tf.tile(input=position_embeds, multiples=(input_shape[0], 1, 1))
        final_embeddings = inputs_embeds + position_embeds

        return final_embeddings
  563. class TFGroupViTStage(keras.layers.Layer):
  564. """This corresponds to the `GroupingLayer` class in the GroupViT implementation."""
  565. def __init__(
  566. self,
  567. config: GroupViTVisionConfig,
  568. depth: int,
  569. num_prev_group_token: int,
  570. num_group_token: int,
  571. num_output_group: int,
  572. **kwargs,
  573. ):
  574. super().__init__(**kwargs)
  575. self.config = config
  576. self.depth = depth
  577. self.num_group_token = num_group_token
  578. self.layers = [TFGroupViTEncoderLayer(config, name=f"layers_._{i}") for i in range(depth)]
  579. if num_group_token > 0:
  580. self.downsample = TFGroupViTTokenAssign(
  581. config=config,
  582. num_group_token=num_group_token,
  583. num_output_group=num_output_group,
  584. name="downsample",
  585. )
  586. else:
  587. self.downsample = None
  588. if num_prev_group_token > 0 and num_group_token > 0:
  589. self.group_projector = [
  590. keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="group_projector.0"),
  591. TFGroupViTMixerMLP(
  592. config, num_prev_group_token, config.hidden_size // 2, num_group_token, name="group_projector.1"
  593. ),
  594. ]
  595. else:
  596. self.group_projector = None
  597. def build(self, input_shape=None):
  598. if self.num_group_token > 0:
  599. self.group_token = self.add_weight(
  600. shape=(1, self.num_group_token, self.config.hidden_size),
  601. initializer="zeros",
  602. trainable=True,
  603. name="group_token",
  604. )
  605. else:
  606. self.group_token = None
  607. if self.built:
  608. return
  609. self.built = True
  610. if getattr(self, "downsample", None) is not None:
  611. with tf.name_scope(self.downsample.name):
  612. self.downsample.build(None)
  613. if getattr(self, "layers", None) is not None:
  614. for layer in self.layers:
  615. with tf.name_scope(layer.name):
  616. layer.build(None)
  617. if getattr(self, "group_projector", None) is not None:
  618. with tf.name_scope(self.group_projector[0].name):
  619. self.group_projector[0].build([None, None, self.config.hidden_size])
  620. with tf.name_scope(self.group_projector[1].name):
  621. self.group_projector[1].build(None)
  622. @property
  623. def with_group_token(self):
  624. return self.group_token is not None
  625. def split_x(self, x: tf.Tensor) -> tf.Tensor:
  626. if self.with_group_token:
  627. return x[:, : -self.num_group_token], x[:, -self.num_group_token :]
  628. else:
  629. return x, None
  630. def concat_x(self, x: tf.Tensor, group_token: tf.Tensor | None = None) -> tf.Tensor:
  631. if group_token is None:
  632. return x
  633. return tf.concat([x, group_token], axis=1)
  634. def call(
  635. self,
  636. hidden_states: tf.Tensor,
  637. prev_group_token: tf.Tensor | None = None,
  638. output_attentions: bool = False,
  639. training: bool = False,
  640. ) -> Tuple[tf.Tensor]:
  641. """
  642. Args:
  643. hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  644. attention_mask (`tf.Tensor`): attention mask of size
  645. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  646. `(config.encoder_attention_heads,)`.
  647. output_attentions (`bool`, *optional*):
  648. Whether or not to return the grouping tensors of Grouping block.
  649. """
  650. if self.with_group_token:
  651. group_token = tf.tile(self.group_token, multiples=(shape_list(hidden_states)[0], 1, 1))
  652. if self.group_projector is not None:
  653. for layer in self.group_projector:
  654. prev_group_token = layer(prev_group_token)
  655. group_token = group_token + prev_group_token
  656. else:
  657. group_token = None
  658. x = hidden_states
  659. cat_x = self.concat_x(x, group_token)
  660. for layer in self.layers:
  661. layer_out = layer(
  662. cat_x,
  663. attention_mask=None,
  664. causal_attention_mask=None,
  665. output_attentions=None,
  666. )
  667. cat_x = layer_out[0]
  668. x, group_token = self.split_x(cat_x)
  669. attention = None
  670. if self.downsample is not None:
  671. x, attention = self.downsample(x, group_token)
  672. outputs = (x, group_token)
  673. if output_attentions:
  674. outputs = outputs + (attention,)
  675. return outputs
  676. class TFGroupViTMLP(keras.layers.Layer):
  677. def __init__(
  678. self,
  679. config: GroupViTVisionConfig,
  680. hidden_size: Optional[int] = None,
  681. intermediate_size: Optional[int] = None,
  682. output_size: Optional[int] = None,
  683. **kwargs,
  684. ):
  685. super().__init__(**kwargs)
  686. self.config = config
  687. self.activation_fn = get_tf_activation(config.hidden_act)
  688. hidden_size = hidden_size if hidden_size is not None else config.hidden_size
  689. intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
  690. output_size = output_size if output_size is not None else hidden_size
  691. self.fc1 = keras.layers.Dense(intermediate_size, name="fc1")
  692. self.fc2 = keras.layers.Dense(output_size, name="fc2")
  693. self.intermediate_size = intermediate_size
  694. self.hidden_size = hidden_size
  695. def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
  696. hidden_states = self.fc1(hidden_states)
  697. hidden_states = self.activation_fn(hidden_states)
  698. hidden_states = self.fc2(hidden_states)
  699. return hidden_states
  700. def build(self, input_shape=None):
  701. if self.built:
  702. return
  703. self.built = True
  704. if getattr(self, "fc1", None) is not None:
  705. with tf.name_scope(self.fc1.name):
  706. self.fc1.build([None, None, self.hidden_size])
  707. if getattr(self, "fc2", None) is not None:
  708. with tf.name_scope(self.fc2.name):
  709. self.fc2.build([None, None, self.intermediate_size])
  710. class TFGroupViTMixerMLP(TFGroupViTMLP):
  711. def call(self, x, training: bool = False):
  712. x = super().call(hidden_states=tf.transpose(x, perm=(0, 2, 1)))
  713. return tf.transpose(x, perm=(0, 2, 1))
  714. # Adapted from transformers.models.clip.modeling_tf_clip.TFCLIPAttention
  715. class TFGroupViTAttention(keras.layers.Layer):
  716. """Multi-headed attention from 'Attention Is All You Need' paper"""
  717. def __init__(self, config: GroupViTConfig, **kwargs):
  718. super().__init__(**kwargs)
  719. self.embed_dim = config.hidden_size
  720. self.num_attention_heads = config.num_attention_heads
  721. self.attention_head_size = self.embed_dim // self.num_attention_heads
  722. if self.attention_head_size * self.num_attention_heads != self.embed_dim:
  723. raise ValueError(
  724. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  725. f" {self.num_attention_heads})."
  726. )
  727. factor = config.initializer_factor
  728. in_proj_std = (self.embed_dim**-0.5) * ((2 * config.num_hidden_layers) ** -0.5) * factor
  729. out_proj_std = (self.embed_dim**-0.5) * factor
  730. self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
  731. self.q_proj = keras.layers.Dense(
  732. units=self.embed_dim, kernel_initializer=get_initializer(in_proj_std), name="q_proj"
  733. )
  734. self.k_proj = keras.layers.Dense(
  735. units=self.embed_dim, kernel_initializer=get_initializer(in_proj_std), name="k_proj"
  736. )
  737. self.v_proj = keras.layers.Dense(
  738. units=self.embed_dim, kernel_initializer=get_initializer(in_proj_std), name="v_proj"
  739. )
  740. self.dropout = keras.layers.Dropout(rate=config.attention_dropout)
  741. self.out_proj = keras.layers.Dense(
  742. units=self.embed_dim, kernel_initializer=get_initializer(out_proj_std), name="out_proj"
  743. )
  744. # Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention.transpose_for_scores
  745. def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
  746. # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
  747. tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
  748. # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
  749. return tf.transpose(tensor, perm=[0, 2, 1, 3])
  750. def call(
  751. self,
  752. hidden_states: tf.Tensor,
  753. attention_mask: tf.Tensor = None,
  754. causal_attention_mask: tf.Tensor = None,
  755. output_attentions: bool = None,
  756. encoder_hidden_states: tf.Tensor = None,
  757. training: bool = False,
  758. ) -> Tuple[tf.Tensor]:
  759. """Input shape: Batch x Time x Channel"""
  760. batch_size = shape_list(hidden_states)[0]
  761. is_cross_attention = encoder_hidden_states is not None
  762. mixed_query_layer = self.q_proj(inputs=hidden_states)
  763. if is_cross_attention:
  764. mixed_key_layer = self.k_proj(inputs=encoder_hidden_states)
  765. mixed_value_layer = self.v_proj(inputs=encoder_hidden_states)
  766. else:
  767. mixed_key_layer = self.k_proj(inputs=hidden_states)
  768. mixed_value_layer = self.v_proj(inputs=hidden_states)
  769. query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
  770. key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
  771. value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
  772. # Take the dot product between "query" and "key" to get the raw attention scores.
  773. # (batch size, num_heads, seq_len_q, seq_len_k)
  774. attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
  775. dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
  776. attention_scores = tf.divide(attention_scores, dk)
  777. # apply the causal_attention_mask first
  778. if causal_attention_mask is not None:
  779. # Apply the causal attention mask (precomputed for all layers in TFCLIPModel call() function)
  780. attention_scores = tf.add(attention_scores, causal_attention_mask)
  781. if attention_mask is not None:
  782. # Apply the attention mask (precomputed for all layers in TFCLIPModel call() function)
  783. attention_scores = tf.add(attention_scores, attention_mask)
  784. # Normalize the attention scores to probabilities.
  785. _attention_probs = stable_softmax(logits=attention_scores, axis=-1)
  786. # This is actually dropping out entire tokens to attend to, which might
  787. # seem a bit unusual, but is taken from the original Transformer paper.
  788. attention_probs = self.dropout(inputs=_attention_probs)
  789. attention_output = tf.matmul(attention_probs, value_layer)
  790. attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
  791. # (batch_size, seq_len_q, embed_dim)
  792. attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.embed_dim))
  793. attention_output = self.out_proj(attention_output)
  794. # In TFBert, attention weights are returned after dropout.
  795. # However, in CLIP, they are returned before dropout.
  796. outputs = (attention_output, _attention_probs) if output_attentions else (attention_output,)
  797. return outputs
  798. def build(self, input_shape=None):
  799. if self.built:
  800. return
  801. self.built = True
  802. if getattr(self, "q_proj", None) is not None:
  803. with tf.name_scope(self.q_proj.name):
  804. self.q_proj.build([None, None, self.embed_dim])
  805. if getattr(self, "k_proj", None) is not None:
  806. with tf.name_scope(self.k_proj.name):
  807. self.k_proj.build([None, None, self.embed_dim])
  808. if getattr(self, "v_proj", None) is not None:
  809. with tf.name_scope(self.v_proj.name):
  810. self.v_proj.build([None, None, self.embed_dim])
  811. if getattr(self, "out_proj", None) is not None:
  812. with tf.name_scope(self.out_proj.name):
  813. self.out_proj.build([None, None, self.embed_dim])
# Copied from transformers.models.clip.modeling_tf_clip.TFCLIPEncoderLayer with CLIP->GroupViT
class TFGroupViTEncoderLayer(keras.layers.Layer):
    """Pre-LayerNorm transformer block: self-attention then MLP, each wrapped in a residual connection."""

    def __init__(self, config: GroupViTConfig, **kwargs):
        super().__init__(**kwargs)

        self.embed_dim = config.hidden_size
        self.self_attn = TFGroupViTAttention(config, name="self_attn")
        self.layer_norm1 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1")
        self.mlp = TFGroupViTMLP(config, name="mlp")
        self.layer_norm2 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2")

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        causal_attention_mask: tf.Tensor,
        output_attentions: bool,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        """
        Args:
            hidden_states (`tf.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`tf.Tensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            causal_attention_mask (`tf.Tensor`): causal attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`):
                Whether or not to return the attentions tensors of all attention layers. See `outputs` under returned
                tensors for more detail.
        """
        # Attention sub-block: LayerNorm is applied *before* attention (pre-norm).
        residual = hidden_states

        hidden_states = self.layer_norm1(inputs=hidden_states)
        attention_outputs = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            training=training,
        )
        hidden_states = attention_outputs[0]
        hidden_states = residual + hidden_states

        # MLP sub-block, also pre-norm with a residual connection.
        residual = hidden_states
        hidden_states = self.layer_norm2(inputs=hidden_states)
        hidden_states = self.mlp(hidden_states=hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,) + attention_outputs[1:]  # add attentions if we output them

        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "self_attn", None) is not None:
            with tf.name_scope(self.self_attn.name):
                self.self_attn.build(None)
        if getattr(self, "layer_norm1", None) is not None:
            with tf.name_scope(self.layer_norm1.name):
                self.layer_norm1.build([None, None, self.embed_dim])
        if getattr(self, "mlp", None) is not None:
            with tf.name_scope(self.mlp.name):
                self.mlp.build(None)
        if getattr(self, "layer_norm2", None) is not None:
            with tf.name_scope(self.layer_norm2.name):
                self.layer_norm2.build([None, None, self.embed_dim])
  875. # Adapted from transformers.models.clip.modeling_tf_clip.TFGroupViTTextEncoder
  876. class TFGroupViTTextEncoder(keras.layers.Layer):
  877. def __init__(self, config: GroupViTTextConfig, **kwargs):
  878. super().__init__(**kwargs)
  879. self.layers = [TFGroupViTEncoderLayer(config, name=f"layers_._{i}") for i in range(config.num_hidden_layers)]
  880. def call(
  881. self,
  882. hidden_states,
  883. attention_mask: tf.Tensor,
  884. causal_attention_mask: tf.Tensor,
  885. output_attentions: bool,
  886. output_hidden_states: bool,
  887. return_dict: bool,
  888. training: bool = False,
  889. ) -> Union[Tuple, TFBaseModelOutput]:
  890. encoder_states = () if output_hidden_states else None
  891. all_attentions = () if output_attentions else None
  892. for idx, encoder_layer in enumerate(self.layers):
  893. if output_hidden_states:
  894. encoder_states = encoder_states + (hidden_states,)
  895. layer_outputs = encoder_layer(
  896. hidden_states,
  897. attention_mask,
  898. causal_attention_mask,
  899. output_attentions=output_attentions,
  900. )
  901. hidden_states = layer_outputs[0]
  902. if output_attentions:
  903. all_attentions = all_attentions + (layer_outputs[1],)
  904. if output_hidden_states:
  905. encoder_states = encoder_states + (hidden_states,)
  906. if not return_dict:
  907. return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
  908. return TFBaseModelOutput(
  909. last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
  910. )
  911. def build(self, input_shape=None):
  912. if self.built:
  913. return
  914. self.built = True
  915. if getattr(self, "layers", None) is not None:
  916. for layer in self.layers:
  917. with tf.name_scope(layer.name):
  918. layer.build(None)
  919. class TFGroupViTVisionEncoder(keras.layers.Layer):
  920. def __init__(self, config: GroupViTVisionConfig, **kwargs) -> None:
  921. super().__init__(**kwargs)
  922. self.stages = [
  923. TFGroupViTStage(
  924. config=config,
  925. depth=config.depths[i],
  926. num_group_token=config.num_group_tokens[i],
  927. num_output_group=config.num_output_groups[i],
  928. num_prev_group_token=config.num_output_groups[i - 1] if i > 0 else 0,
  929. name=f"stages_._{i}",
  930. )
  931. for i in range(len(config.depths))
  932. ]
  933. def call(
  934. self,
  935. hidden_states: tf.Tensor,
  936. output_hidden_states: bool,
  937. output_attentions: bool,
  938. return_dict: bool,
  939. training: bool = False,
  940. ) -> Union[tuple, TFBaseModelOutput]:
  941. all_hidden_states = () if output_hidden_states else None
  942. all_groupings = () if output_attentions else None
  943. group_tokens = None
  944. for stage in self.stages:
  945. if output_hidden_states:
  946. all_hidden_states = all_hidden_states + (hidden_states,)
  947. layer_outputs = stage(hidden_states, group_tokens, output_attentions)
  948. hidden_states = layer_outputs[0]
  949. group_tokens = layer_outputs[1]
  950. if output_attentions and layer_outputs[2] is not None:
  951. all_groupings = all_groupings + (layer_outputs[2],)
  952. if output_hidden_states:
  953. all_hidden_states = all_hidden_states + (hidden_states,)
  954. if not return_dict:
  955. return tuple(v for v in [hidden_states, all_hidden_states, all_groupings] if v is not None)
  956. return TFBaseModelOutput(
  957. last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_groupings
  958. )
  959. def build(self, input_shape=None):
  960. if self.built:
  961. return
  962. self.built = True
  963. if getattr(self, "stages", None) is not None:
  964. for layer in self.stages:
  965. with tf.name_scope(layer.name):
  966. layer.build(None)
# Copied from transformers.models.clip.modeling_tf_clip.TFCLIPTextTransformer with CLIPText->GroupViTText, CLIPEncoder->GroupViTTextEncoder
class TFGroupViTTextTransformer(keras.layers.Layer):
    """Text transformer: embeddings -> causal encoder -> final LayerNorm, with EOS-token pooling."""

    def __init__(self, config: GroupViTTextConfig, **kwargs):
        super().__init__(**kwargs)

        self.embeddings = TFGroupViTTextEmbeddings(config, name="embeddings")
        self.encoder = TFGroupViTTextEncoder(config, name="encoder")
        self.final_layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="final_layer_norm")

        # For `pooled_output` computation
        self.eos_token_id = config.eos_token_id
        self.embed_dim = config.hidden_size

    def call(
        self,
        input_ids: TFModelInputType,
        attention_mask: tf.Tensor,
        position_ids: tf.Tensor,
        output_attentions: bool,
        output_hidden_states: bool,
        return_dict: bool,
        training: bool = False,
    ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
        input_shape = shape_list(input_ids)

        embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids)

        batch_size, seq_length = input_shape
        # CLIP's text model uses causal mask, prepare it here.
        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
        causal_attention_mask = self._build_causal_attention_mask(batch_size, seq_length, dtype=embedding_output.dtype)

        # check attention mask and invert
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        attention_mask = _expand_mask(attention_mask)

        encoder_outputs = self.encoder(
            hidden_states=embedding_output,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        sequence_output = encoder_outputs[0]
        sequence_output = self.final_layer_norm(inputs=sequence_output)

        if self.eos_token_id == 2:
            # The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here.
            # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
            # ------------------------------------------------------------
            # text_embeds.shape = [batch_size, n_ctx, transformer.width]
            # take features from the eot embedding (eot_token is the highest number in each sequence)
            pooled_output = tf.gather_nd(
                params=sequence_output,
                indices=tf.stack(
                    values=(tf.range(input_shape[0], dtype=tf.int64), tf.math.argmax(input_ids, axis=-1)), axis=1
                ),
            )
        else:
            # The config gets updated `eos_token_id` from PR #24773 (so the use of exta new tokens is possible)
            # Pool at the *first* occurrence of the EOS token in each sequence:
            # argmax over the boolean match returns the earliest True index.
            pooled_output = tf.gather_nd(
                params=sequence_output,
                indices=tf.stack(
                    values=(
                        tf.range(input_shape[0], dtype=tf.int64),
                        tf.math.argmax(tf.cast(input_ids == self.eos_token_id, dtype=tf.int8), axis=-1),
                    ),
                    axis=1,
                ),
            )

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return TFBaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def _build_causal_attention_mask(self, batch_size, seq_length, dtype=tf.float32):
        """Build an additive causal mask of shape (batch, 1, seq, seq): 0 on/below the diagonal, -10000 above."""
        # It is possible with an unspecified sequence length for seq_length to be
        # a runtime value, which is unsupported by tf.constant. Per the TensorFlow
        # docs, tf.fill can handle runtime dynamic shapes:
        # https://www.tensorflow.org/api_docs/python/tf/fill
        diag = tf.cast(tf.fill((seq_length,), 0.0), dtype)

        # set an additive 2D attention mask with all places being masked
        to_mask = tf.cast(tf.fill((seq_length, seq_length), -10000.0), dtype)

        # set diagonal & lower triangular parts to 0 (i.e. the places not to be masked)
        # TIP: think the 2D matrix as the space of (query_seq, key_seq)
        to_mask = tf.linalg.band_part(to_mask, 0, -1)
        # to_mask = tf.linalg.band_part(to_mask, -1, 0)
        to_mask = tf.linalg.set_diag(to_mask, diagonal=diag)

        return tf.broadcast_to(input=to_mask, shape=(batch_size, 1, seq_length, seq_length))

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "embeddings", None) is not None:
            with tf.name_scope(self.embeddings.name):
                self.embeddings.build(None)
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)
        if getattr(self, "final_layer_norm", None) is not None:
            with tf.name_scope(self.final_layer_norm.name):
                self.final_layer_norm.build([None, None, self.embed_dim])
  1066. # Adapted from transformers.models.clip.modeling_tf_clip.TFCLIPVisionTransformer
  1067. class TFGroupViTVisionTransformer(keras.layers.Layer):
  1068. def __init__(self, config: GroupViTVisionConfig, **kwargs):
  1069. super().__init__(**kwargs)
  1070. self.embeddings = TFGroupViTVisionEmbeddings(config, name="embeddings")
  1071. self.encoder = TFGroupViTVisionEncoder(config, name="encoder")
  1072. self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
  1073. self.embed_dim = config.hidden_size
  1074. def call(
  1075. self,
  1076. pixel_values: TFModelInputType,
  1077. output_attentions: bool,
  1078. output_hidden_states: bool,
  1079. return_dict: bool,
  1080. training: bool = False,
  1081. ) -> Union[Tuple, TFBaseModelOutputWithPooling]:
  1082. embedding_output = self.embeddings(pixel_values)
  1083. encoder_outputs = self.encoder(
  1084. hidden_states=embedding_output,
  1085. output_hidden_states=output_hidden_states,
  1086. output_attentions=output_attentions,
  1087. return_dict=return_dict,
  1088. )
  1089. last_hidden_state = encoder_outputs[0]
  1090. # normalize the last hidden state
  1091. last_hidden_state = self.layernorm(last_hidden_state)
  1092. pooled_output = tf.math.reduce_mean(last_hidden_state, axis=1)
  1093. if not return_dict:
  1094. return (last_hidden_state, pooled_output) + encoder_outputs[1:]
  1095. return TFBaseModelOutputWithPooling(
  1096. last_hidden_state=last_hidden_state,
  1097. pooler_output=pooled_output,
  1098. hidden_states=encoder_outputs.hidden_states,
  1099. attentions=encoder_outputs.attentions,
  1100. )
  1101. def build(self, input_shape=None):
  1102. if self.built:
  1103. return
  1104. self.built = True
  1105. if getattr(self, "embeddings", None) is not None:
  1106. with tf.name_scope(self.embeddings.name):
  1107. self.embeddings.build(None)
  1108. if getattr(self, "encoder", None) is not None:
  1109. with tf.name_scope(self.encoder.name):
  1110. self.encoder.build(None)
  1111. if getattr(self, "layernorm", None) is not None:
  1112. with tf.name_scope(self.layernorm.name):
  1113. self.layernorm.build([None, None, self.embed_dim])
@keras_serializable
# Copied from transformers.models.clip.modeling_tf_clip.TFCLIPTextMainLayer with CLIP->GroupViT
class TFGroupViTTextMainLayer(keras.layers.Layer):
    """Keras-serializable wrapper around the text transformer."""

    config_class = GroupViTTextConfig

    def __init__(self, config: GroupViTTextConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.text_model = TFGroupViTTextTransformer(config, name="text_model")

    def get_input_embeddings(self) -> keras.layers.Layer:
        return self.text_model.embeddings

    def set_input_embeddings(self, value: tf.Variable):
        self.text_model.embeddings.weight = value
        self.text_model.embeddings.vocab_size = shape_list(value)[0]

    @unpack_inputs
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        input_shape = shape_list(input_ids)

        if attention_mask is None:
            # Default: attend to every token.
            attention_mask = tf.fill(dims=input_shape, value=1)

        text_model_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        return text_model_outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "text_model", None) is not None:
            with tf.name_scope(self.text_model.name):
                self.text_model.build(None)
@keras_serializable
# Copied from transformers.models.clip.modeling_tf_clip.TFCLIPVisionMainLayer with CLIP->GroupViT
class TFGroupViTVisionMainLayer(keras.layers.Layer):
    """Keras-serializable wrapper around the vision transformer."""

    config_class = GroupViTVisionConfig

    def __init__(self, config: GroupViTVisionConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.vision_model = TFGroupViTVisionTransformer(config, name="vision_model")

    def get_input_embeddings(self) -> keras.layers.Layer:
        return self.vision_model.embeddings

    @unpack_inputs
    def call(
        self,
        pixel_values: TFModelInputType | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        vision_model_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        return vision_model_outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "vision_model", None) is not None:
            with tf.name_scope(self.vision_model.name):
                self.vision_model.build(None)
@keras_serializable
# Adapted from transformers.models.clip.modeling_tf_clip.TFCLIPMainLayer
class TFGroupViTMainLayer(keras.layers.Layer):
    """
    Combined GroupViT trunk: runs the text and vision transformers, projects their
    pooled outputs into a shared embedding space, and computes contrastive
    image/text logits (plus, optionally, per-pixel segmentation logits derived
    from the vision model's grouping attentions).
    """

    config_class = GroupViTConfig

    def __init__(self, config: GroupViTConfig, **kwargs):
        super().__init__(**kwargs)

        if not isinstance(config.text_config, GroupViTTextConfig):
            raise TypeError(
                "config.text_config is expected to be of type GroupViTTextConfig but is of type"
                f" {type(config.text_config)}."
            )

        if not isinstance(config.vision_config, GroupViTVisionConfig):
            raise TypeError(
                "config.vision_config is expected to be of type GroupViTVisionConfig but is of type"
                f" {type(config.vision_config)}."
            )

        self.config = config

        text_config = config.text_config
        vision_config = config.vision_config

        self.projection_dim = config.projection_dim
        self.projection_intermediate_dim = config.projection_intermediate_dim
        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size

        self.text_model = TFGroupViTTextTransformer(text_config, name="text_model")
        self.vision_model = TFGroupViTVisionTransformer(vision_config, name="vision_model")

        # Projection heads: Dense -> BatchNorm -> ReLU -> Dense, kept as plain
        # Python lists (not keras.Sequential). The "visual_projection.0"-style
        # names presumably mirror the PyTorch checkpoint's state-dict keys for
        # cross-framework weight loading — confirm against the PT model.
        self.visual_projection = [
            keras.layers.Dense(self.projection_intermediate_dim, name="visual_projection.0"),
            keras.layers.BatchNormalization(name="visual_projection.1", momentum=0.9, epsilon=1e-5),
            keras.layers.ReLU(name="visual_projection.2"),
            keras.layers.Dense(self.projection_dim, name="visual_projection.3"),
        ]
        self.text_projection = [
            keras.layers.Dense(self.projection_intermediate_dim, name="text_projection.0"),
            keras.layers.BatchNormalization(name="text_projection.1", momentum=0.9, epsilon=1e-5),
            keras.layers.ReLU(name="text_projection.2"),
            keras.layers.Dense(self.projection_dim, name="text_projection.3"),
        ]

    def build(self, input_shape=None):
        """Create the logit-scale weight and build all sublayers."""
        # NOTE(review): logit_scale is created *before* the `self.built` guard,
        # so a repeated build() call would re-add this weight — matches the
        # upstream layout, but confirm this ordering is intentional.
        self.logit_scale = self.add_weight(
            shape=(1,),
            initializer=keras.initializers.Constant(self.config.logit_scale_init_value),
            trainable=True,
            name="logit_scale",
        )

        if self.built:
            return
        self.built = True
        if getattr(self, "text_model", None) is not None:
            with tf.name_scope(self.text_model.name):
                self.text_model.build(None)
        if getattr(self, "vision_model", None) is not None:
            with tf.name_scope(self.vision_model.name):
                self.vision_model.build(None)
        # Indices 1 and 3 of each projection list are the weight-bearing layers
        # after the first Dense; index 2 (ReLU) has no weights and needs no build.
        if getattr(self, "visual_projection", None) is not None:
            with tf.name_scope(self.visual_projection[0].name):
                self.visual_projection[0].build([None, None, None, self.vision_embed_dim])
            with tf.name_scope(self.visual_projection[1].name):
                self.visual_projection[1].build((None, self.projection_intermediate_dim))
            with tf.name_scope(self.visual_projection[3].name):
                self.visual_projection[3].build([None, None, None, self.projection_intermediate_dim])
        if getattr(self, "text_projection", None) is not None:
            with tf.name_scope(self.text_projection[0].name):
                self.text_projection[0].build([None, None, None, self.text_embed_dim])
            with tf.name_scope(self.text_projection[1].name):
                self.text_projection[1].build((None, self.projection_intermediate_dim))
            with tf.name_scope(self.text_projection[3].name):
                self.text_projection[3].build([None, None, None, self.projection_intermediate_dim])

    @unpack_inputs
    def get_text_features(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> tf.Tensor:
        """Return the projected pooled text embedding for `input_ids`."""
        if input_ids is None:
            raise ValueError("You have to specify either input_ids")

        input_shape = shape_list(input_ids)

        # Default to attending over every token.
        if attention_mask is None:
            attention_mask = tf.fill(dims=input_shape, value=1)

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        # Index 1 is the pooled output; push it through the projection head.
        pooled_output = text_outputs[1]
        for layer in self.text_projection:
            pooled_output = layer(pooled_output)

        text_features = pooled_output
        return text_features

    @unpack_inputs
    def get_image_features(
        self,
        pixel_values: TFModelInputType | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> tf.Tensor:
        """Return the projected pooled image embedding for `pixel_values`."""
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        # Index 1 is the pooled output; push it through the projection head.
        pooled_output = vision_outputs[1]
        for layer in self.visual_projection:
            pooled_output = layer(pooled_output)

        image_features = pooled_output
        return image_features

    @unpack_inputs
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        pixel_values: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        position_ids: np.ndarray | tf.Tensor | None = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_segmentation: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[TFGroupViTModelOutput, Tuple[tf.Tensor]]:
        """
        Run both towers, compute contrastive logits, and optionally the
        per-pixel segmentation logits and contrastive loss.
        """
        if input_ids is None:
            raise ValueError("You have to specify either input_ids")
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        input_shape = shape_list(input_ids)

        if attention_mask is None:
            attention_mask = tf.fill(dims=input_shape, value=1)

        # Segmentation is derived from the vision attentions, so force them on.
        if output_segmentation:
            output_attentions = True
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        # Pooled outputs (index 1) -> shared projection space.
        image_embeds = vision_outputs[1]
        for layer in self.visual_projection:
            image_embeds = layer(image_embeds)

        text_embeds = text_outputs[1]
        for layer in self.text_projection:
            text_embeds = layer(text_embeds)

        # normalized features
        image_embeds = image_embeds / tf.norm(image_embeds, axis=-1, keepdims=True)
        text_embeds = text_embeds / tf.norm(text_embeds, axis=-1, keepdims=True)

        # cosine similarity as logits, scaled by the learned temperature.
        logit_scale = tf.math.exp(self.logit_scale)
        logits_per_text = tf.matmul(text_embeds, image_embeds, transpose_b=True) * logit_scale
        logits_per_image = tf.transpose(logits_per_text)

        seg_logits = None
        if output_segmentation:
            # grouped features
            # [batch_size_image, num_group, hidden_size]
            image_group_embeds = vision_outputs[0]
            # [batch_size_image*num_group, hidden_size]
            image_group_embeds = tf.reshape(image_group_embeds, shape=(-1, shape_list(image_group_embeds)[-1]))
            for layer in self.visual_projection:
                image_group_embeds = layer(image_group_embeds)
            # The attentions' slot in the output tuple shifts by one when
            # hidden_states are also returned.
            if output_hidden_states:
                attentions = vision_outputs[3]
            else:
                attentions = vision_outputs[2]
            # [batch_size_image, num_group, height, width]
            grouping = get_grouping_from_attentions(attentions, pixel_values.shape[2:])

            # normalized features
            image_group_embeds = image_group_embeds / tf.norm(
                tensor=image_group_embeds, ord="euclidean", axis=-1, keepdims=True
            )
            # [batch_size_image x num_group, batch_size_text]
            logits_per_image_group = tf.matmul(image_group_embeds, text_embeds, transpose_b=True) * logit_scale
            # [batch_size_image, batch_size_text, num_group]
            logits_per_image_group = tf.reshape(
                logits_per_image_group, shape=(image_embeds.shape[0], -1, text_embeds.shape[0])
            )
            logits_per_image_group = tf.transpose(logits_per_image_group, perm=(0, 2, 1))

            # [batch_size_image, batch_size_text, height x width]
            flatten_grouping = tf.reshape(grouping, shape=(shape_list(grouping)[0], shape_list(grouping)[1], -1))

            # [batch_size_image, batch_size_text, height, width]
            seg_logits = tf.matmul(logits_per_image_group, flatten_grouping) * logit_scale
            seg_logits = tf.reshape(
                seg_logits, shape=(seg_logits.shape[0], seg_logits.shape[1], grouping.shape[2], grouping.shape[3])
            )

        loss = None
        if return_loss:
            # Loss is wrapped in a leading axis for Keras compatibility.
            loss = groupvit_loss(logits_per_text)[None, ...]

        if not return_dict:
            if seg_logits is not None:
                output = (
                    logits_per_image,
                    logits_per_text,
                    seg_logits,
                    text_embeds,
                    image_embeds,
                    text_outputs,
                    vision_outputs,
                )
            else:
                output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
            return ((loss,) + output) if loss is not None else output

        return TFGroupViTModelOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            segmentation_logits=seg_logits,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )
  1428. class TFGroupViTPreTrainedModel(TFPreTrainedModel):
  1429. """
  1430. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  1431. models.
  1432. """
  1433. config_class = GroupViTConfig
  1434. base_model_prefix = "groupvit"
  1435. GROUPVIT_START_DOCSTRING = r"""
  1436. This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
  1437. library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
  1438. etc.)
  1439. This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
  1440. as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
  1441. behavior.
  1442. <Tip>
  1443. TF 2.0 models accepts two formats as inputs:
  1444. - having all inputs as keyword arguments (like PyTorch models), or
  1445. - having all inputs as a list, tuple or dict in the first positional arguments.
  1446. This second option is useful when using [`keras.Model.fit`] method which currently requires having all the
  1447. tensors in the first argument of the model call function: `model(inputs)`.
  1448. If you choose this second option, there are three possibilities you can use to gather all the input Tensors in the
  1449. first positional argument :
  1450. - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
  1451. - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
  1452. `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
  1453. - a dictionary with one or several input Tensors associated to the input names given in the docstring:
  1454. `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
  1455. </Tip>
  1456. Args:
  1457. config ([`GroupViTConfig`]): Model configuration class with all the parameters of the model.
  1458. Initializing with a config file does not load the weights associated with the model, only the
  1459. configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
  1460. """
  1461. GROUPVIT_TEXT_INPUTS_DOCSTRING = r"""
  1462. Args:
  1463. input_ids (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]` ``Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `({0})`):
  1464. Indices of input sequence tokens in the vocabulary.
  1465. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
  1466. [`PreTrainedTokenizer.encode`] for details.
  1467. [What are input IDs?](../glossary#input-ids)
  1468. attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
  1469. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  1470. - 1 for tokens that are **not masked**,
  1471. - 0 for tokens that are **masked**.
  1472. [What are attention masks?](../glossary#attention-mask)
  1473. position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
  1474. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  1475. config.max_position_embeddings - 1]`.
  1476. [What are position IDs?](../glossary#position-ids)
  1477. output_attentions (`bool`, *optional*):
  1478. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  1479. tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
  1480. config will be used instead.
  1481. output_hidden_states (`bool`, *optional*):
  1482. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  1483. more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
  1484. used instead.
  1485. return_dict (`bool`, *optional*):
  1486. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
  1487. eager mode, in graph mode the value will always be set to True.
  1488. training (`bool`, *optional*, defaults to `False``):
  1489. Whether or not to use the model in training mode (some modules like dropout modules have different
  1490. behaviors between training and evaluation).
  1491. """
  1492. GROUPVIT_VISION_INPUTS_DOCSTRING = r"""
  1493. Args:
  1494. pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`):
  1495. Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
  1496. [`CLIPImageProcessor.__call__`] for details.
  1497. output_attentions (`bool`, *optional*):
  1498. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  1499. tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
  1500. config will be used instead.
  1501. output_hidden_states (`bool`, *optional*):
  1502. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  1503. more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
  1504. used instead.
  1505. return_dict (`bool`, *optional*):
  1506. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
  1507. eager mode, in graph mode the value will always be set to True.
  1508. training (`bool`, *optional*, defaults to `False``):
  1509. Whether or not to use the model in training mode (some modules like dropout modules have different
  1510. behaviors between training and evaluation).
  1511. """
  1512. GROUPVIT_INPUTS_DOCSTRING = r"""
  1513. Args:
  1514. input_ids (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]` ``Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `({0})`):
  1515. Indices of input sequence tokens in the vocabulary.
  1516. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
  1517. [`PreTrainedTokenizer.encode`] for details.
  1518. [What are input IDs?](../glossary#input-ids)
  1519. pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]` `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`):
  1520. Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
  1521. [`CLIPImageProcessor.__call__`] for details.
  1522. attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
  1523. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  1524. - 1 for tokens that are **not masked**,
  1525. - 0 for tokens that are **masked**.
  1526. [What are attention masks?](../glossary#attention-mask)
  1527. position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
  1528. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  1529. config.max_position_embeddings - 1]`.
  1530. [What are position IDs?](../glossary#position-ids)
  1531. return_loss (`bool`, *optional*):
  1532. Whether or not to return the contrastive loss.
  1533. output_attentions (`bool`, *optional*):
  1534. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  1535. tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
  1536. config will be used instead.
  1537. output_hidden_states (`bool`, *optional*):
  1538. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  1539. more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
  1540. used instead.
  1541. return_dict (`bool`, *optional*):
  1542. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
  1543. eager mode, in graph mode the value will always be set to True.
  1544. training (`bool`, *optional*, defaults to `False``):
  1545. Whether or not to use the model in training mode (some modules like dropout modules have different
  1546. behaviors between training and evaluation).
  1547. """
  1548. class TFGroupViTTextModel(TFGroupViTPreTrainedModel):
  1549. config_class = GroupViTTextConfig
  1550. main_input_name = "input_ids"
  1551. def __init__(self, config: GroupViTTextConfig, *inputs, **kwargs):
  1552. super().__init__(config, *inputs, **kwargs)
  1553. self.groupvit = TFGroupViTTextMainLayer(config, name="groupvit")
  1554. @unpack_inputs
  1555. @add_start_docstrings_to_model_forward(GROUPVIT_TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  1556. @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=GroupViTTextConfig)
  1557. def call(
  1558. self,
  1559. input_ids: TFModelInputType | None = None,
  1560. attention_mask: np.ndarray | tf.Tensor | None = None,
  1561. position_ids: np.ndarray | tf.Tensor | None = None,
  1562. output_attentions: Optional[bool] = None,
  1563. output_hidden_states: Optional[bool] = None,
  1564. return_dict: Optional[bool] = None,
  1565. training: bool = False,
  1566. ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
  1567. r"""
  1568. Returns:
  1569. Examples:
  1570. ```python
  1571. >>> from transformers import CLIPTokenizer, TFGroupViTTextModel
  1572. >>> tokenizer = CLIPTokenizer.from_pretrained("nvidia/groupvit-gcc-yfcc")
  1573. >>> model = TFGroupViTTextModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
  1574. >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="tf")
  1575. >>> outputs = model(**inputs)
  1576. >>> last_hidden_state = outputs.last_hidden_state
  1577. >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
  1578. ```"""
  1579. outputs = self.groupvit(
  1580. input_ids=input_ids,
  1581. attention_mask=attention_mask,
  1582. position_ids=position_ids,
  1583. output_attentions=output_attentions,
  1584. output_hidden_states=output_hidden_states,
  1585. return_dict=return_dict,
  1586. training=training,
  1587. )
  1588. return outputs
  1589. def build(self, input_shape=None):
  1590. if self.built:
  1591. return
  1592. self.built = True
  1593. if getattr(self, "groupvit", None) is not None:
  1594. with tf.name_scope(self.groupvit.name):
  1595. self.groupvit.build(None)
  1596. class TFGroupViTVisionModel(TFGroupViTPreTrainedModel):
  1597. config_class = GroupViTVisionConfig
  1598. main_input_name = "pixel_values"
  1599. def __init__(self, config: GroupViTVisionConfig, *inputs, **kwargs):
  1600. super().__init__(config, *inputs, **kwargs)
  1601. self.groupvit = TFGroupViTVisionMainLayer(config, name="groupvit")
  1602. @unpack_inputs
  1603. @add_start_docstrings_to_model_forward(GROUPVIT_VISION_INPUTS_DOCSTRING)
  1604. @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=GroupViTVisionConfig)
  1605. def call(
  1606. self,
  1607. pixel_values: TFModelInputType | None = None,
  1608. output_attentions: Optional[bool] = None,
  1609. output_hidden_states: Optional[bool] = None,
  1610. return_dict: Optional[bool] = None,
  1611. training: bool = False,
  1612. ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
  1613. r"""
  1614. Returns:
  1615. Examples:
  1616. ```python
  1617. >>> from PIL import Image
  1618. >>> import requests
  1619. >>> from transformers import AutoProcessor, TFGroupViTVisionModel
  1620. >>> processor = AutoProcessor.from_pretrained("nvidia/groupvit-gcc-yfcc")
  1621. >>> model = TFGroupViTVisionModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
  1622. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  1623. >>> image = Image.open(requests.get(url, stream=True).raw)
  1624. >>> inputs = processor(images=image, return_tensors="tf")
  1625. >>> outputs = model(**inputs)
  1626. >>> last_hidden_state = outputs.last_hidden_state
  1627. >>> pooled_output = outputs.pooler_output # pooled CLS states
  1628. ```"""
  1629. outputs = self.groupvit(
  1630. pixel_values=pixel_values,
  1631. output_attentions=output_attentions,
  1632. output_hidden_states=output_hidden_states,
  1633. return_dict=return_dict,
  1634. training=training,
  1635. )
  1636. return outputs
  1637. def build(self, input_shape=None):
  1638. if self.built:
  1639. return
  1640. self.built = True
  1641. if getattr(self, "groupvit", None) is not None:
  1642. with tf.name_scope(self.groupvit.name):
  1643. self.groupvit.build(None)
  1644. @add_start_docstrings(GROUPVIT_START_DOCSTRING)
  1645. class TFGroupViTModel(TFGroupViTPreTrainedModel):
  1646. config_class = GroupViTConfig
  1647. def __init__(self, config: GroupViTConfig, *inputs, **kwargs):
  1648. super().__init__(config, *inputs, **kwargs)
  1649. self.groupvit = TFGroupViTMainLayer(config, name="groupvit")
  1650. @unpack_inputs
  1651. @add_start_docstrings_to_model_forward(GROUPVIT_TEXT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  1652. def get_text_features(
  1653. self,
  1654. input_ids: TFModelInputType | None = None,
  1655. attention_mask: np.ndarray | tf.Tensor | None = None,
  1656. position_ids: np.ndarray | tf.Tensor | None = None,
  1657. output_attentions: Optional[bool] = None,
  1658. output_hidden_states: Optional[bool] = None,
  1659. return_dict: Optional[bool] = None,
  1660. training: bool = False,
  1661. ) -> tf.Tensor:
  1662. r"""
  1663. Returns:
  1664. text_features (`tf.Tensor` of shape `(batch_size, output_dim`): The text embeddings obtained by applying
  1665. the projection layer to the pooled output of [`TFGroupViTTextModel`].
  1666. Examples:
  1667. ```python
  1668. >>> from transformers import CLIPTokenizer, TFGroupViTModel
  1669. >>> model = TFGroupViTModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
  1670. >>> tokenizer = CLIPTokenizer.from_pretrained("nvidia/groupvit-gcc-yfcc")
  1671. >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="tf")
  1672. >>> text_features = model.get_text_features(**inputs)
  1673. ```"""
  1674. text_features = self.groupvit.get_text_features(
  1675. input_ids=input_ids,
  1676. attention_mask=attention_mask,
  1677. position_ids=position_ids,
  1678. output_attentions=output_attentions,
  1679. output_hidden_states=output_hidden_states,
  1680. return_dict=return_dict,
  1681. training=training,
  1682. )
  1683. return text_features
  1684. @unpack_inputs
  1685. @add_start_docstrings_to_model_forward(GROUPVIT_VISION_INPUTS_DOCSTRING)
  1686. def get_image_features(
  1687. self,
  1688. pixel_values: TFModelInputType | None = None,
  1689. output_attentions: Optional[bool] = None,
  1690. output_hidden_states: Optional[bool] = None,
  1691. return_dict: Optional[bool] = None,
  1692. training: bool = False,
  1693. ) -> tf.Tensor:
  1694. r"""
  1695. Returns:
  1696. image_features (`tf.Tensor` of shape `(batch_size, output_dim`): The image embeddings obtained by applying
  1697. the projection layer to the pooled output of [`TFGroupViTVisionModel`].
  1698. Examples:
  1699. ```python
  1700. >>> from PIL import Image
  1701. >>> import requests
  1702. >>> from transformers import AutoProcessor, TFGroupViTModel
  1703. >>> model = TFGroupViTModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
  1704. >>> processor = AutoProcessor.from_pretrained("nvidia/groupvit-gcc-yfcc")
  1705. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  1706. >>> image = Image.open(requests.get(url, stream=True).raw)
  1707. >>> inputs = processor(images=image, return_tensors="tf")
  1708. >>> image_features = model.get_image_features(**inputs)
  1709. ```"""
  1710. image_features = self.groupvit.get_image_features(
  1711. pixel_values=pixel_values,
  1712. output_attentions=output_attentions,
  1713. output_hidden_states=output_hidden_states,
  1714. return_dict=return_dict,
  1715. training=training,
  1716. )
  1717. return image_features
  1718. @unpack_inputs
  1719. @add_start_docstrings_to_model_forward(GROUPVIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  1720. @replace_return_docstrings(output_type=TFGroupViTModelOutput, config_class=GroupViTConfig)
  1721. def call(
  1722. self,
  1723. input_ids: TFModelInputType | None = None,
  1724. pixel_values: TFModelInputType | None = None,
  1725. attention_mask: np.ndarray | tf.Tensor | None = None,
  1726. position_ids: np.ndarray | tf.Tensor | None = None,
  1727. return_loss: Optional[bool] = None,
  1728. output_attentions: Optional[bool] = None,
  1729. output_hidden_states: Optional[bool] = None,
  1730. output_segmentation: Optional[bool] = None,
  1731. return_dict: Optional[bool] = None,
  1732. training: bool = False,
  1733. ) -> Union[TFGroupViTModelOutput, Tuple[tf.Tensor]]:
  1734. r"""
  1735. Returns:
  1736. Examples:
  1737. ```python
  1738. >>> from PIL import Image
  1739. >>> import requests
  1740. >>> from transformers import AutoProcessor, TFGroupViTModel
  1741. >>> import tensorflow as tf
  1742. >>> model = TFGroupViTModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
  1743. >>> processor = AutoProcessor.from_pretrained("nvidia/groupvit-gcc-yfcc")
  1744. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  1745. >>> image = Image.open(requests.get(url, stream=True).raw)
  1746. >>> inputs = processor(
  1747. ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="tf", padding=True
  1748. ... )
  1749. >>> outputs = model(**inputs)
  1750. >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
  1751. >>> probs = tf.math.softmax(logits_per_image, axis=1) # we can take the softmax to get the label probabilities
  1752. ```"""
  1753. outputs = self.groupvit(
  1754. input_ids=input_ids,
  1755. pixel_values=pixel_values,
  1756. attention_mask=attention_mask,
  1757. position_ids=position_ids,
  1758. return_loss=return_loss,
  1759. output_attentions=output_attentions,
  1760. output_hidden_states=output_hidden_states,
  1761. output_segmentation=output_segmentation,
  1762. return_dict=return_dict,
  1763. training=training,
  1764. )
  1765. return outputs
  1766. def serving_output(self, output: TFGroupViTModelOutput) -> TFGroupViTModelOutput:
  1767. # TODO: As is this currently fails with saved_model=True, because
  1768. # TensorFlow cannot trace through nested dataclasses. Reference:
  1769. # https://github.com/huggingface/transformers/pull/16886
  1770. return output
  1771. def build(self, input_shape=None):
  1772. if self.built:
  1773. return
  1774. self.built = True
  1775. if getattr(self, "groupvit", None) is not None:
  1776. with tf.name_scope(self.groupvit.name):
  1777. self.groupvit.build(None)