modeling_tf_sam.py

# coding=utf-8
# Copyright 2023 The Meta AI Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
TensorFlow SAM model. This file was mostly generated by auto-translation from the PyTorch original. In the event of a
discrepancy, the original file should be regarded as the 'reference' version.
"""

from __future__ import annotations

import collections
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import numpy as np
import tensorflow as tf

from ...activations_tf import ACT2FN
from ...modeling_tf_outputs import TFBaseModelOutput
from ...modeling_tf_utils import TFModelInputType, TFPreTrainedModel, keras, shape_list, unpack_inputs
from ...tf_utils import flatten, functional_layernorm
from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "SamConfig"
_CHECKPOINT_FOR_DOC = "facebook/sam-vit-huge"
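
# Illustrative usage (a sketch, not executed by this module): the layers below are assembled into TFSamModel,
# which is normally driven through SamProcessor from this library. Roughly:
#
#     from transformers import SamProcessor, TFSamModel
#
#     processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
#     model = TFSamModel.from_pretrained("facebook/sam-vit-huge")
#     inputs = processor(raw_image, input_points=[[[450, 600]]], return_tensors="tf")  # raw_image: a PIL image
#     outputs = model(**inputs)  # outputs.pred_masks, outputs.iou_scores
#
# Treat the exact preprocessing arguments as an assumption; they depend on the processor version.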


@dataclass
class TFSamVisionEncoderOutput(ModelOutput):
    """
    Base class for the SAM vision model's outputs that also contains image embeddings obtained by applying the
    projection layer to the pooler_output.

    Args:
        image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
            The image embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
            the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    image_embeds: tf.Tensor | None = None
    last_hidden_state: tf.Tensor = None
    hidden_states: Tuple[tf.Tensor, ...] | None = None
    attentions: Tuple[tf.Tensor, ...] | None = None


@dataclass
class TFSamImageSegmentationOutput(ModelOutput):
    """
    Base class for the Segment-Anything model's output.

    Args:
        iou_scores (`tf.Tensor` of shape `(batch_size, num_masks)`):
            The IoU scores of the predicted masks.
        pred_masks (`tf.Tensor` of shape `(batch_size, num_masks, height, width)`):
            The predicted low-resolution masks. They need to be post-processed by the processor.
        vision_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
            the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs.
        vision_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        mask_decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    iou_scores: tf.Tensor = None
    pred_masks: tf.Tensor = None
    vision_hidden_states: Tuple[tf.Tensor, ...] | None = None
    vision_attentions: Tuple[tf.Tensor, ...] | None = None
    mask_decoder_attentions: Tuple[tf.Tensor, ...] | None = None


class TFSamPatchEmbeddings(keras.layers.Layer):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size
        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches
        self.projection = keras.layers.Conv2D(
            hidden_size, kernel_size=patch_size, strides=patch_size, name="projection"
        )

    def call(self, pixel_values):
        batch_size, num_channels, height, width = shape_list(pixel_values)
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        if height != self.image_size[0] or width != self.image_size[1]:
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
            )
        embeddings = self.projection(tf.transpose(pixel_values, perm=[0, 2, 3, 1]))
        return embeddings

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "projection", None) is not None:
            with tf.name_scope(self.projection.name):
                self.projection.build([None, None, None, self.num_channels])
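
# Shape sketch (illustrative, using the default SAM vision config where image_size=1024 and patch_size=16):
# pixel_values of shape (batch, 3, 1024, 1024) are transposed to channels-last and projected with a stride-16
# convolution, giving patch embeddings of shape (batch, 64, 64, hidden_size). The concrete numbers depend on
# the config and are only meant as an example.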


class TFSamMLPBlock(keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.lin1 = keras.layers.Dense(config.mlp_dim, name="lin1")
        self.lin2 = keras.layers.Dense(config.hidden_size, name="lin2")
        self.act = ACT2FN[config.hidden_act]
        self.config = config

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        hidden_states = self.lin1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.lin2(hidden_states)
        return hidden_states

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "lin1", None) is not None:
            with tf.name_scope(self.lin1.name):
                self.lin1.build([None, None, self.config.hidden_size])
        if getattr(self, "lin2", None) is not None:
            with tf.name_scope(self.lin2.name):
                self.lin2.build([None, None, self.config.mlp_dim])


class TFSamLayerNorm(keras.layers.Layer):
    r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
    `data_format` controls the ordering of the dimensions in the inputs: channels_last corresponds to inputs with
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs with shape (batch_size,
    channels, height, width).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", **kwargs):
        super().__init__(**kwargs)
        self.eps = eps
        self.data_format = data_format
        self.normalized_shape = normalized_shape
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError(f"Unsupported data format: {self.data_format}")

    def build(self, input_shape):
        self.weight = self.add_weight(shape=self.normalized_shape, initializer="ones", name="weight")
        self.bias = self.add_weight(shape=self.normalized_shape, initializer="zeros", name="bias")
        super().build(input_shape)

    def call(self, x: tf.Tensor) -> tf.Tensor:
        if self.data_format == "channels_last":
            x = functional_layernorm(x, weight=self.weight, bias=self.bias, epsilon=self.eps, axis=-1)
        elif self.data_format == "channels_first":
            x = functional_layernorm(x, weight=self.weight, bias=self.bias, epsilon=self.eps, axis=1)
        return x
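
# Usage note (illustrative): TFSamLayerNorm(256, data_format="channels_first") normalizes (batch, 256, H, W)
# tensors over their channel axis (axis 1), while the default "channels_last" mode normalizes the trailing
# axis of (batch, H, W, 256) tensors; 256 here is only an example value.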


class TFSamAttention(keras.layers.Layer):
    """
    SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
    values.
    """

    def __init__(self, config, downsample_rate=None, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = config.hidden_size
        downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
        self.internal_dim = config.hidden_size // downsample_rate
        self.num_attention_heads = config.num_attention_heads
        if self.internal_dim % config.num_attention_heads != 0:
            raise ValueError("num_attention_heads must divide internal_dim (hidden_size // downsample_rate).")
        self.q_proj = keras.layers.Dense(self.internal_dim, name="q_proj")
        self.k_proj = keras.layers.Dense(self.internal_dim, name="k_proj")
        self.v_proj = keras.layers.Dense(self.internal_dim, name="v_proj")
        self.out_proj = keras.layers.Dense(self.hidden_size, name="out_proj")

    def _separate_heads(self, hidden_states: tf.Tensor, num_attention_heads: int) -> tf.Tensor:
        batch, point_batch_size, n_tokens, channel = shape_list(hidden_states)
        c_per_head = channel // num_attention_heads
        hidden_states = tf.reshape(
            hidden_states, (batch * point_batch_size, n_tokens, num_attention_heads, c_per_head)
        )
        return tf.transpose(hidden_states, perm=[0, 2, 1, 3])

    def _recombine_heads(self, hidden_states: tf.Tensor, point_batch_size: int) -> tf.Tensor:
        batch, n_heads, n_tokens, c_per_head = shape_list(hidden_states)
        hidden_states = tf.transpose(hidden_states, perm=[0, 2, 1, 3])
        return tf.reshape(
            hidden_states,
            (batch // tf.reduce_max([1, point_batch_size]), point_batch_size, n_tokens, n_heads * c_per_head),
        )

    def call(self, query: tf.Tensor, key: tf.Tensor, value: tf.Tensor) -> tf.Tensor:
        # Input projections
        query = self.q_proj(query)
        key = self.k_proj(key)
        value = self.v_proj(value)
        point_batch_size = shape_list(query)[1]
        # Separate into heads
        query = self._separate_heads(query, self.num_attention_heads)
        key = self._separate_heads(key, self.num_attention_heads)
        value = self._separate_heads(value, self.num_attention_heads)
        # SamAttention
        _, _, _, c_per_head = shape_list(query)
        attn = tf.matmul(
            query, tf.transpose(key, perm=[0, 1, 3, 2])
        )  # batch_size * point_batch_size x N_heads x N_tokens x N_tokens
        attn = attn / tf.math.sqrt(float(c_per_head))
        attn = tf.nn.softmax(attn, axis=-1)
        # Get output
        out = tf.matmul(attn, value)
        out = self._recombine_heads(out, point_batch_size)
        out = self.out_proj(out)
        return out
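
    # Shape sketch (illustrative): with query of shape (batch, point_batch_size, n_tokens, hidden_size),
    # _separate_heads folds the point batch into the batch axis and splits channels into heads, giving
    # (batch * point_batch_size, num_heads, n_tokens, internal_dim // num_heads); _recombine_heads undoes
    # both steps after the weighted sum, so out_proj sees (batch, point_batch_size, n_tokens, internal_dim).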

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "q_proj", None) is not None:
            with tf.name_scope(self.q_proj.name):
                self.q_proj.build([None, None, self.hidden_size])
        if getattr(self, "k_proj", None) is not None:
            with tf.name_scope(self.k_proj.name):
                self.k_proj.build([None, None, self.hidden_size])
        if getattr(self, "v_proj", None) is not None:
            with tf.name_scope(self.v_proj.name):
                self.v_proj.build([None, None, self.hidden_size])
        if getattr(self, "out_proj", None) is not None:
            with tf.name_scope(self.out_proj.name):
                self.out_proj.build([None, None, self.internal_dim])


class TFSamTwoWayAttentionBlock(keras.layers.Layer):
    def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False, **kwargs):
        """
        A transformer block with four layers:
            (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on
            sparse inputs (4) cross attention of dense inputs -> sparse inputs

        Arguments:
            config (`SamMaskDecoderConfig`):
                The configuration file used to instantiate the block
            attention_downsample_rate (*optional*, int, defaults to 2):
                The downsample ratio of the block used to reduce the inner dim of the attention.
            skip_first_layer_pe (*optional*, bool, defaults to `False`):
                Whether or not to skip the addition of the query_point_embedding on the first layer.
        """
        super().__init__(**kwargs)
        self.hidden_size = config.hidden_size
        self.layer_norm_eps = config.layer_norm_eps
        self.self_attn = TFSamAttention(config, downsample_rate=1, name="self_attn")
        self.layer_norm1 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm1")
        self.cross_attn_token_to_image = TFSamAttention(
            config, downsample_rate=attention_downsample_rate, name="cross_attn_token_to_image"
        )
        self.layer_norm2 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm2")
        self.mlp = TFSamMLPBlock(config, name="mlp")
        self.layer_norm3 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm3")
        self.layer_norm4 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm4")
        self.cross_attn_image_to_token = TFSamAttention(
            config, downsample_rate=attention_downsample_rate, name="cross_attn_image_to_token"
        )
        self.skip_first_layer_pe = skip_first_layer_pe

    def call(
        self,
        queries: tf.Tensor,
        keys: tf.Tensor,
        query_point_embedding: tf.Tensor,
        key_point_embedding: tf.Tensor,
        output_attentions: bool = False,
    ):
        # Self attention block
        if self.skip_first_layer_pe:
            queries = self.self_attn(query=queries, key=queries, value=queries)
        else:
            query = queries + query_point_embedding
            attn_out = self.self_attn(query=query, key=query, value=queries)
            queries = queries + attn_out
        queries = self.layer_norm1(queries)
        # Cross attention block, tokens attending to image embedding
        query = queries + query_point_embedding
        key = keys + key_point_embedding
        attn_out = self.cross_attn_token_to_image(query=query, key=key, value=keys)
        queries = queries + attn_out
        queries = self.layer_norm2(queries)
        # MLP block
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.layer_norm3(queries)
        # Cross attention block, image embedding attending to tokens
        query = queries + query_point_embedding
        key = keys + key_point_embedding
        attn_out = self.cross_attn_image_to_token(query=key, key=query, value=queries)
        keys = keys + attn_out
        keys = self.layer_norm4(keys)
        outputs = (queries, keys)
        if output_attentions:
            outputs = outputs + (attn_out,)
        else:
            outputs = outputs + (None,)
        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "self_attn", None) is not None:
            with tf.name_scope(self.self_attn.name):
                self.self_attn.build(None)
        if getattr(self, "layer_norm1", None) is not None:
            with tf.name_scope(self.layer_norm1.name):
                self.layer_norm1.build([None, None, None, self.hidden_size])
        if getattr(self, "cross_attn_token_to_image", None) is not None:
            with tf.name_scope(self.cross_attn_token_to_image.name):
                self.cross_attn_token_to_image.build(None)
        if getattr(self, "layer_norm2", None) is not None:
            with tf.name_scope(self.layer_norm2.name):
                self.layer_norm2.build([None, None, None, self.hidden_size])
        if getattr(self, "mlp", None) is not None:
            with tf.name_scope(self.mlp.name):
                self.mlp.build(None)
        if getattr(self, "layer_norm3", None) is not None:
            with tf.name_scope(self.layer_norm3.name):
                self.layer_norm3.build([None, None, None, self.hidden_size])
        if getattr(self, "layer_norm4", None) is not None:
            with tf.name_scope(self.layer_norm4.name):
                self.layer_norm4.build([None, None, None, self.hidden_size])
        if getattr(self, "cross_attn_image_to_token", None) is not None:
            with tf.name_scope(self.cross_attn_image_to_token.name):
                self.cross_attn_image_to_token.build(None)


class TFSamTwoWayTransformer(keras.layers.Layer):
    def __init__(self, config: SamMaskDecoderConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.num_hidden_layers = config.num_hidden_layers
        self.layers = []
        for i in range(self.num_hidden_layers):
            self.layers.append(TFSamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0), name=f"layers_._{i}"))
        self.final_attn_token_to_image = TFSamAttention(config, name="final_attn_token_to_image")
        self.layer_norm_final_attn = keras.layers.LayerNormalization(
            epsilon=config.layer_norm_eps, name="layer_norm_final_attn"
        )

    def call(
        self,
        point_embeddings: tf.Tensor,
        image_embeddings: tf.Tensor,
        image_positional_embeddings: tf.Tensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TFBaseModelOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        all_attentions = ()
        if image_embeddings is None:
            raise ValueError("You have to specify an image_embedding")
        image_embeddings = tf.transpose(flatten(image_embeddings, 2), perm=(0, 2, 1))[:, None]
        image_positional_embeddings = tf.transpose(flatten(image_positional_embeddings, 2), (0, 2, 1))[:, None]
        # Prepare queries
        queries = point_embeddings
        keys = image_embeddings
        # Apply transformer blocks and final layernorm
        for layer in self.layers:
            queries, keys, attention_outputs = layer(
                queries=queries,
                keys=keys,
                query_point_embedding=point_embeddings,
                key_point_embedding=image_positional_embeddings,
                output_attentions=output_attentions,
            )
            if output_attentions:
                all_attentions = all_attentions + (attention_outputs,)
        # Apply the final attention layer from the points to the image
        query = queries + point_embeddings
        key = keys + image_positional_embeddings
        attn_out = self.final_attn_token_to_image(query=query, key=key, value=keys)
        queries = queries + attn_out
        queries = self.layer_norm_final_attn(queries)
        return queries, keys, all_attentions

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "final_attn_token_to_image", None) is not None:
            with tf.name_scope(self.final_attn_token_to_image.name):
                self.final_attn_token_to_image.build(None)
        if getattr(self, "layer_norm_final_attn", None) is not None:
            with tf.name_scope(self.layer_norm_final_attn.name):
                self.layer_norm_final_attn.build([None, None, None, self.config.hidden_size])
        for layer in self.layers:
            with tf.name_scope(layer.name):
                layer.build(None)


class TFSamFeedForward(keras.layers.Layer):
    def __init__(
        self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False, **kwargs
    ):
        super().__init__(**kwargs)
        self.num_layers = num_layers
        self.activation = keras.layers.ReLU()
        self.proj_in = keras.layers.Dense(hidden_dim, input_shape=(input_dim,), name="proj_in")
        self.proj_out = keras.layers.Dense(output_dim, input_shape=(hidden_dim,), name="proj_out")
        self.layers = [
            keras.layers.Dense(hidden_dim, input_shape=(hidden_dim,), name=f"layers_._{i}")
            for i in range(num_layers - 2)
        ]
        self.sigmoid_output = sigmoid_output
        self.hidden_dim = hidden_dim
        self.input_dim = input_dim

    def call(self, hidden_states):
        hidden_states = self.proj_in(hidden_states)
        hidden_states = self.activation(hidden_states)
        for layer in self.layers:
            hidden_states = self.activation(layer(hidden_states))
        hidden_states = self.proj_out(hidden_states)
        if self.sigmoid_output:
            hidden_states = tf.sigmoid(hidden_states)
        return hidden_states

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "proj_in", None) is not None:
            with tf.name_scope(self.proj_in.name):
                self.proj_in.build([None, None, self.input_dim])
        if getattr(self, "proj_out", None) is not None:
            with tf.name_scope(self.proj_out.name):
                self.proj_out.build([None, None, self.hidden_dim])
        if getattr(self, "layers", None) is not None:
            for layer in self.layers:
                with tf.name_scope(layer.name):
                    layer.build([None, None, self.hidden_dim])
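
# Depth sketch (illustrative): TFSamFeedForward(256, 256, 32, 3) builds proj_in (256 -> 256), one hidden Dense
# layer (num_layers - 2 == 1) and proj_out (256 -> 32), i.e. num_layers Dense layers in total with ReLU in
# between. The concrete dimensions are only an example; the mask decoder below picks them from its config.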


class TFSamMaskDecoder(keras.layers.Layer):
    def __init__(self, config: SamMaskDecoderConfig, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = config.hidden_size
        self.num_multimask_outputs = config.num_multimask_outputs
        self.num_mask_tokens = config.num_multimask_outputs + 1
        self.transformer = TFSamTwoWayTransformer(config, name="transformer")
        self.upscale_conv1 = keras.layers.Conv2DTranspose(
            self.hidden_size // 4, kernel_size=2, strides=2, name="upscale_conv1", data_format="channels_first"
        )
        self.upscale_conv2 = keras.layers.Conv2DTranspose(
            self.hidden_size // 8, kernel_size=2, strides=2, name="upscale_conv2", data_format="channels_first"
        )
        self.upscale_layer_norm = TFSamLayerNorm(
            self.hidden_size // 4, data_format="channels_first", name="upscale_layer_norm"
        )
        self.activation = tf.nn.gelu
        mlps_list = []
        for i in range(self.num_mask_tokens):
            mlps_list += [
                TFSamFeedForward(
                    self.hidden_size,
                    self.hidden_size,
                    self.hidden_size // 8,
                    3,
                    name=f"output_hypernetworks_mlps_._{i}",
                )
            ]
        self.output_hypernetworks_mlps = mlps_list
        self.iou_prediction_head = TFSamFeedForward(
            self.hidden_size,
            config.iou_head_hidden_dim,
            self.num_mask_tokens,
            config.iou_head_depth,
            name="iou_prediction_head",
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        self.iou_token = self.add_weight(shape=(1, self.hidden_size), name="iou_token.weight", trainable=True)
        self.mask_tokens = self.add_weight(
            shape=(self.num_mask_tokens, self.hidden_size), name="mask_tokens.weight", trainable=True
        )
        if getattr(self, "transformer", None) is not None:
            with tf.name_scope(self.transformer.name):
                self.transformer.build(None)
        if getattr(self, "upscale_conv1", None) is not None:
            with tf.name_scope(self.upscale_conv1.name):
                self.upscale_conv1.build([None, self.hidden_size, None, None])
        if getattr(self, "upscale_conv2", None) is not None:
            with tf.name_scope(self.upscale_conv2.name):
                self.upscale_conv2.build([None, self.hidden_size // 4, None, None])
        if getattr(self, "upscale_layer_norm", None) is not None:
            with tf.name_scope(self.upscale_layer_norm.name):
                self.upscale_layer_norm.build(None)
        if getattr(self, "iou_prediction_head", None) is not None:
            with tf.name_scope(self.iou_prediction_head.name):
                self.iou_prediction_head.build(None)
        for mlp in self.output_hypernetworks_mlps:
            with tf.name_scope(mlp.name):
                mlp.build(None)

    def call(
        self,
        image_embeddings: tf.Tensor,
        image_positional_embeddings: tf.Tensor,
        sparse_prompt_embeddings: tf.Tensor,
        dense_prompt_embeddings: tf.Tensor,
        multimask_output: bool,
        output_attentions: Optional[bool] = None,
    ) -> Tuple[tf.Tensor, tf.Tensor]:
        batch_size, num_channels, height, width = shape_list(image_embeddings)
        point_batch_size = tf.math.maximum(1, tf.shape(sparse_prompt_embeddings)[1])

        output_tokens = tf.concat([self.iou_token, self.mask_tokens], axis=0)  # Should be (num_mask_tokens + 1, hidden_size)
        output_tokens = tf.tile(
            output_tokens[None, None, :], [batch_size, point_batch_size, 1, 1]
        )  # Should be (batch_size, point_batch_size, num_mask_tokens + 1, hidden_size)

        # Matt: The original Torch code checked that the sum of sparse_prompt_embeddings equalled 0. However, this only
        #       happens when the sparse prompt embeddings are an empty tensor with shape[1] == 0. I replaced
        #       it with an explicit shape check to avoid data-dependent control flow which breaks XLA.
        if shape_list(sparse_prompt_embeddings)[1] != 0:
            tokens = tf.concat((output_tokens, sparse_prompt_embeddings), axis=2)
        else:
            tokens = output_tokens
        point_embeddings = tf.cast(tokens, self.iou_token.dtype)

        image_embeddings = image_embeddings + dense_prompt_embeddings
        image_embeddings = tf.repeat(image_embeddings, point_batch_size, axis=0)
        image_positional_embeddings = tf.repeat(image_positional_embeddings, point_batch_size, axis=0)

        point_embedding, image_embeddings, attentions = self.transformer(
            point_embeddings=point_embeddings,
            image_embeddings=image_embeddings,
            image_positional_embeddings=image_positional_embeddings,
            output_attentions=output_attentions,
        )
        iou_token_out = point_embedding[:, :, 0, :]
        mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :]

        image_embeddings = tf.transpose(image_embeddings, perm=(0, 1, 3, 2))
        image_embeddings = tf.reshape(image_embeddings, [batch_size * point_batch_size, num_channels, height, width])

        upscaled_embedding = self.upscale_conv1(image_embeddings)
        upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
        upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding))

        hyper_in_list = []
        for i in range(self.num_mask_tokens):
            current_mlp = self.output_hypernetworks_mlps[i]
            hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
        hyper_in = tf.stack(hyper_in_list, axis=2)

        _, num_channels, height, width = shape_list(upscaled_embedding)
        upscaled_embedding = tf.reshape(
            upscaled_embedding, [batch_size, point_batch_size, num_channels, height * width]
        )
        masks = tf.reshape(hyper_in @ upscaled_embedding, [batch_size, point_batch_size, -1, height, width])

        iou_pred = self.iou_prediction_head(iou_token_out)

        if multimask_output:
            mask_slice = slice(1, None)
        else:
            mask_slice = slice(0, 1)
        masks = masks[:, :, mask_slice, :, :]
        iou_pred = iou_pred[:, :, mask_slice]

        outputs = (masks, iou_pred)
        if output_attentions:
            outputs = outputs + (attentions,)
        else:
            outputs = outputs + (None,)
        return outputs
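
# Shape sketch for the decoder above (illustrative, default configs: hidden_size=256, num_mask_tokens=4, and
# a 64x64x256 image embedding): the two-way transformer returns one embedding per token, the two transposed
# convolutions upscale the image embedding to 256x256 with 32 channels, and each mask token is mapped by its
# hypernetwork MLP to a 32-dim filter that is matmul'ed against the flattened upscaled embedding to produce a
# 256x256 low-resolution mask; the IoU head scores each mask. Exact sizes depend on the checkpoint config.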


class TFSamPositionalEmbedding(keras.layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.scale = config.hidden_size // 2
        self.config = config

    def build(self, input_shape):
        # TODO Matt: What is going on here? Why is a non-trainable weight randomly initialized?
        self.positional_embedding = self.add_weight(
            name="positional_embedding",
            shape=(2, self.config.num_pos_feats),
            initializer=keras.initializers.RandomNormal(mean=0.0, stddev=self.scale),
            trainable=False,
        )
        super().build(input_shape)

    def call(self, input_coords, input_shape=None):
        """Positionally encode points that are normalized to [0,1]."""
        coordinates = tf.identity(input_coords)
        if input_shape is not None:
            coordinates = tf.stack(
                [
                    tf.cast(coordinates[:, :, :, 0], tf.float32) / input_shape[1],
                    tf.cast(coordinates[:, :, :, 1], tf.float32) / input_shape[0],
                ],
                axis=-1,
            )
        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
        coordinates = 2 * coordinates - 1
        coordinates = tf.cast(coordinates, self.positional_embedding.dtype)
        coordinates = tf.matmul(coordinates, self.positional_embedding)
        coordinates = 2 * np.pi * coordinates
        # outputs d_1 x ... x d_n x channel shape
        return tf.concat([tf.sin(coordinates), tf.cos(coordinates)], axis=-1)
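
# This is SAM's random Fourier-feature positional encoding: coordinates normalized to [0, 1] are mapped to
# [-1, 1], projected with a fixed random (2, num_pos_feats) matrix, scaled by 2*pi and passed through sin/cos,
# yielding 2 * num_pos_feats channels per point. The projection matrix is non-trainable but is still part of
# the saved weights, so loading a pretrained checkpoint reproduces the exact encoding used at training time
# (which, at least as far as checkpoint loading is concerned, addresses the TODO above).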


class TFSamMaskEmbedding(keras.layers.Layer):
    def __init__(self, config: SamPromptEncoderConfig, **kwargs):
        super().__init__(**kwargs)
        self.mask_input_channels = config.mask_input_channels // 4
        self.activation = ACT2FN[config.hidden_act]
        self.conv1 = keras.layers.Conv2D(self.mask_input_channels, kernel_size=2, strides=2, name="conv1")
        self.conv2 = keras.layers.Conv2D(config.mask_input_channels, kernel_size=2, strides=2, name="conv2")
        self.conv3 = keras.layers.Conv2D(config.hidden_size, kernel_size=1, name="conv3")
        self.layer_norm1 = TFSamLayerNorm(self.mask_input_channels, config.layer_norm_eps, name="layer_norm1")
        self.layer_norm2 = TFSamLayerNorm(self.mask_input_channels * 4, config.layer_norm_eps, name="layer_norm2")
        self.config = config

    def call(self, masks):
        masks = tf.transpose(masks, perm=(0, 2, 3, 1))  # Convert to channels-last
        hidden_states = self.conv1(masks)
        hidden_states = self.layer_norm1(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.conv2(hidden_states)
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.activation(hidden_states)
        dense_embeddings = self.conv3(hidden_states)
        dense_embeddings = tf.transpose(dense_embeddings, perm=(0, 3, 1, 2))  # Convert back to channels-first
        return dense_embeddings

    def build(self, input_shape=None):
        # This class needs an explicit build method because it isn't called with the standard dummy inputs
        if self.built:
            return
        self.built = True
        with tf.name_scope("conv1"):
            self.conv1.build([None, None, None, 1])
        with tf.name_scope("conv2"):
            self.conv2.build([None, None, None, self.mask_input_channels])
        with tf.name_scope("conv3"):
            self.conv3.build([None, None, None, self.mask_input_channels * 4])
        with tf.name_scope("layer_norm1"):
            self.layer_norm1.build([None, None, None, self.mask_input_channels])
        with tf.name_scope("layer_norm2"):
            self.layer_norm2.build([None, None, None, self.mask_input_channels * 4])
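
# Shape sketch (illustrative, default prompt-encoder config with mask_input_channels=16 and hidden_size=256):
# an input mask of shape (batch, 1, 256, 256) is downscaled by the two stride-2 convolutions to 64x64 while
# the channels go 1 -> 4 -> 16, then the 1x1 convolution projects to 256 channels so the dense embedding
# matches the 64x64 image-embedding grid. Sizes are config-dependent.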


class TFSamPromptEncoder(keras.layers.Layer):
    def __init__(self, config: SamPromptEncoderConfig, shared_patch_embedding, **kwargs):
        super().__init__(**kwargs)
        self.shared_embedding = shared_patch_embedding
        self.mask_embed = TFSamMaskEmbedding(config, name="mask_embed")
        self.no_mask_embed = None
        self.image_embedding_size = (config.image_embedding_size, config.image_embedding_size)
        self.input_image_size = config.image_size
        self.point_embed = []
        self.hidden_size = config.hidden_size
        self.not_a_point_embed = None
        self.config = config

    def build(self, input_shape=None):
        self.no_mask_embed = self.add_weight(
            name="no_mask_embed.weight",
            shape=(1, self.hidden_size),
            initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
            trainable=True,
        )
        self.point_embed = [
            self.add_weight(
                name=f"point_embed_._{i}.weight",
                shape=(1, self.hidden_size),
                initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
                trainable=True,
            )
            for i in range(self.config.num_point_embeddings)
        ]
        self.not_a_point_embed = self.add_weight(
            name="not_a_point_embed.weight",
            shape=(1, self.hidden_size),
            initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
            trainable=True,
        )
        with tf.name_scope("mask_embed"):
            # We must explicitly build the mask embed because it isn't touched by the standard dummy inputs
            self.mask_embed.build(
                (None, self.config.mask_input_channels, self.config.image_size, self.config.image_size)
            )
        if self.built:
            return
        self.built = True
        if getattr(self, "mask_embed", None) is not None:
            with tf.name_scope(self.mask_embed.name):
                self.mask_embed.build(None)

    def _embed_points(self, points: tf.Tensor, labels: tf.Tensor, pad: bool) -> tf.Tensor:
        """Embeds point prompts."""
        points = points + 0.5  # Shift to center of pixel
        if pad:
            target_point_shape = (shape_list(points)[0], shape_list(points)[1], 1, shape_list(points)[-1])
            target_labels_shape = (shape_list(points)[0], shape_list(points)[1], 1)
            padding_point = tf.zeros(target_point_shape, dtype=points.dtype)
            padding_label = -tf.ones(target_labels_shape, dtype=labels.dtype)
            points = tf.concat([points, padding_point], axis=2)
            labels = tf.concat([labels, padding_label], axis=2)
        input_shape = (self.input_image_size, self.input_image_size)
        point_embedding = self.shared_embedding(points, input_shape)
        point_embedding = tf.where(labels[..., None] == -1, self.not_a_point_embed[0], point_embedding)
        point_embedding = tf.where(
            labels[..., None] != -10,
            point_embedding,
            tf.zeros_like(point_embedding),
        )
        point_embedding = tf.where(
            (labels == 0)[:, :, :, None], point_embedding + self.point_embed[0], point_embedding
        )
        point_embedding = tf.where(
            (labels == 1)[:, :, :, None], point_embedding + self.point_embed[1], point_embedding
        )
        return point_embedding

    def _embed_boxes(self, boxes: tf.Tensor) -> tf.Tensor:
        """Embeds box prompts."""
        boxes = boxes + 0.5  # Shift to center of pixel
        batch_size, nb_boxes = shape_list(boxes)[:2]
        coords = tf.reshape(boxes, (batch_size, nb_boxes, 2, 2))
        input_shape = (self.input_image_size, self.input_image_size)
        corner_embedding = self.shared_embedding(coords, input_shape)
        corner_embedding += tf.where(
            tf.range(shape_list(corner_embedding)[2])[None, None, :, None] == 0,
            self.point_embed[2][0],
            self.point_embed[3][0],
        )
        return corner_embedding

    def call(
        self,
        batch_size: Optional[int],
        input_points: Optional[Tuple[tf.Tensor, tf.Tensor]],
        input_labels: tf.Tensor | None,
        input_boxes: tf.Tensor | None,
        input_masks: tf.Tensor | None,
    ) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Embeds different types of prompts, returning both sparse and dense embeddings.

        Args:
            batch_size (`int`, *optional*):
                The batch size of the image embeddings, used when it cannot be inferred from the prompts.
            input_points (`tf.Tensor`, *optional*):
                Point coordinates to embed.
            input_labels (`tf.Tensor`, *optional*):
                Labels for the point prompts (required whenever `input_points` is provided).
            input_boxes (`tf.Tensor`, *optional*):
                Boxes to embed.
            input_masks (`tf.Tensor`, *optional*):
                Masks to embed.
        """
        sparse_embeddings = None
        if input_points is not None:
            batch_size, point_batch_size = shape_list(input_points)[:2]
            if input_labels is None:
                raise ValueError("If points are provided, labels must also be provided.")
            point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
            sparse_embeddings = tf.zeros(
                (batch_size, point_batch_size, 0, self.hidden_size), dtype=point_embeddings.dtype
            )
            sparse_embeddings = tf.concat([sparse_embeddings, point_embeddings], axis=2)
        if input_boxes is not None:
            batch_size = shape_list(input_boxes)[0]
            box_embeddings = self._embed_boxes(input_boxes)
            if sparse_embeddings is None:
                sparse_embeddings = box_embeddings
            else:
                sparse_embeddings = tf.concat([sparse_embeddings, box_embeddings], axis=2)
        if input_masks is not None:
            dense_embeddings = self.mask_embed(input_masks)
        else:
            dense_embeddings = self.no_mask_embed[0]
            dense_embeddings = tf.reshape(dense_embeddings, (1, -1, 1, 1))
            dense_embeddings = tf.tile(
                dense_embeddings, (batch_size, 1, self.image_embedding_size[0], self.image_embedding_size[1])
            )
        if sparse_embeddings is None:
            sparse_embeddings = tf.zeros((batch_size, 0, 1, self.hidden_size), dtype=dense_embeddings.dtype)
        return sparse_embeddings, dense_embeddings
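
# Output sketch (illustrative): sparse_embeddings ends up with shape roughly
# (batch_size, point_batch_size, num_prompt_tokens, hidden_size), where each point contributes one token
# (plus one padding token when no boxes are given) and each box contributes two corner tokens;
# dense_embeddings is channels-first (batch_size, hidden_size, 64, 64) for the default image_embedding_size,
# coming either from the mask-embedding CNN or from the learned no-mask embedding broadcast over the grid.
# The point-label convention mirrors the PyTorch SAM code: 1 = foreground point, 0 = background point,
# -1 = padding point (gets not_a_point_embed), -10 = padding slot whose embedding is zeroed out.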


class TFSamVisionAttention(keras.layers.Layer):
    """Multi-head Attention block with relative position embeddings."""

    def __init__(self, config, window_size, **kwargs):
        super().__init__(**kwargs)
        input_size = (
            (config.image_size // config.patch_size, config.image_size // config.patch_size)
            if window_size == 0
            else (window_size, window_size)
        )
        self.input_size = input_size
        self.num_attention_heads = config.num_attention_heads
        head_dim = config.hidden_size // config.num_attention_heads
        self.head_dim = head_dim
        self.scale = head_dim**-0.5
        self.dropout = config.attention_dropout
        self.qkv = keras.layers.Dense(config.hidden_size * 3, use_bias=config.qkv_bias, name="qkv")
        self.proj = keras.layers.Dense(config.hidden_size, name="proj")
        self.use_rel_pos = config.use_rel_pos
        if self.use_rel_pos:
            if input_size is None:
                raise ValueError("Input size must be provided if using relative positional encoding.")
        self.config = config

    def build(self, input_shape=None):
        if self.input_size is not None:
            # initialize relative positional embeddings
            self.rel_pos_h = self.add_weight(
                shape=(2 * self.input_size[0] - 1, self.head_dim), initializer="zeros", name="rel_pos_h"
            )
            self.rel_pos_w = self.add_weight(
                shape=(2 * self.input_size[1] - 1, self.head_dim), initializer="zeros", name="rel_pos_w"
            )
        if self.built:
            return
        self.built = True
        if getattr(self, "qkv", None) is not None:
            with tf.name_scope(self.qkv.name):
                self.qkv.build([None, None, self.config.hidden_size])
        if getattr(self, "proj", None) is not None:
            with tf.name_scope(self.proj.name):
                self.proj.build([None, None, self.config.hidden_size])

    def get_rel_pos(self, q_size: int, k_size: int, rel_pos: tf.Tensor) -> tf.Tensor:
        """
        Get relative positional embeddings according to the relative positions of
        query and key sizes.

        Args:
            q_size (int):
                size of the query.
            k_size (int):
                size of key k.
            rel_pos (`tf.Tensor`):
                relative position embeddings (L, channel).

        Returns:
            Extracted positional embeddings according to relative positions.
        """
        max_rel_dist = int(2 * max(q_size, k_size) - 1)
        # Interpolate rel pos if needed.
        if rel_pos.shape[0] != max_rel_dist:
            # Interpolate rel pos.
            rel_pos_resized = tf.image.resize(
                tf.reshape(rel_pos, (1, rel_pos.shape[0], -1)),
                size=(max_rel_dist, rel_pos.shape[1]),
                method="bilinear",
            )
            rel_pos_resized = tf.reshape(rel_pos_resized, (-1, max_rel_dist))
        else:
            rel_pos_resized = rel_pos
        # Scale the coords with short length if shapes for q and k are different.
        q_coords = tf.expand_dims(tf.range(q_size, dtype=tf.float32), 1) * max(k_size / q_size, 1.0)
        k_coords = tf.expand_dims(tf.range(k_size, dtype=tf.float32), 0) * max(q_size / k_size, 1.0)
        relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
        return tf.gather(rel_pos_resized, tf.cast(relative_coords, tf.int32))
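
    # Worked example (illustrative): for a 14x14 window, q_size == k_size == 14, so max_rel_dist == 27 and
    # rel_pos holds one embedding per relative offset in [-13, 13]; relative_coords maps the query/key index
    # pair (i, j) to row (i - j) + 13, which is then gathered from rel_pos_resized.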

    def add_decomposed_rel_pos(
        self,
        attn: tf.Tensor,
        query: tf.Tensor,
        rel_pos_h: tf.Tensor,
        rel_pos_w: tf.Tensor,
        q_size: Tuple[int, int],
        k_size: Tuple[int, int],
    ) -> tf.Tensor:
        """
        Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
        https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py

        Args:
            attn (`tf.Tensor`):
                attention map.
            query (`tf.Tensor`):
                query q in the attention layer with shape (batch_size, query_height * query_width, channel).
            rel_pos_h (`tf.Tensor`):
                relative position embeddings (Lh, channel) for height axis.
            rel_pos_w (`tf.Tensor`):
                relative position embeddings (Lw, channel) for width axis.
            q_size (tuple):
                spatial sequence size of query q with (query_height, query_width).
            k_size (tuple):
                spatial sequence size of key k with (key_height, key_width).

        Returns:
            attn (`tf.Tensor`):
                attention map with added relative positional embeddings.
        """
        query_height, query_width = q_size
        key_height, key_width = k_size
        relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
        relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)
        batch_size, _, dim = shape_list(query)
        reshaped_query = tf.reshape(query, (batch_size, query_height, query_width, dim))
        rel_h = tf.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
        rel_w = tf.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
        attn = tf.reshape(attn, (batch_size, query_height, query_width, key_height, key_width))
        attn = attn + tf.expand_dims(rel_h, axis=-1) + tf.expand_dims(rel_w, axis=-2)
        attn = tf.reshape(attn, (batch_size, query_height * query_width, key_height * key_width))
        return attn

    def call(self, hidden_states: tf.Tensor, output_attentions=False, training=False) -> tf.Tensor:
        batch_size, height, width, _ = shape_list(hidden_states)
        # qkv with shape (3, batch_size, nHead, height * width, channel)
        qkv = tf.reshape(self.qkv(hidden_states), (batch_size, height * width, 3, self.num_attention_heads, -1))
        qkv = tf.transpose(qkv, perm=(2, 0, 3, 1, 4))
        # q, k, v with shape (batch_size * nHead, height * width, channel)
        query, key, value = tf.unstack(
            tf.reshape(qkv, (3, batch_size * self.num_attention_heads, height * width, -1)), axis=0
        )
        attn_weights = tf.matmul(query * self.scale, key, transpose_b=True)
        if self.use_rel_pos:
            attn_weights = self.add_decomposed_rel_pos(
                attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
            )
        attn_weights = tf.nn.softmax(attn_weights, axis=-1)
        if training:
            attn_probs = tf.nn.dropout(attn_weights, rate=self.dropout)
        else:
            attn_probs = attn_weights
        attn_output = tf.reshape(attn_probs @ value, (batch_size, self.num_attention_heads, height, width, -1))
        attn_output = tf.transpose(attn_output, perm=(0, 2, 3, 1, 4))
        attn_output = tf.reshape(attn_output, (batch_size, height, width, self.config.hidden_size))
        attn_output = self.proj(attn_output)
        if output_attentions:
            outputs = (attn_output, attn_weights)
        else:
            outputs = (attn_output, None)
        return outputs


class TFSamVisionLayer(keras.layers.Layer):
    def __init__(self, config, window_size, **kwargs):
        super().__init__(**kwargs)
        self.layer_norm1 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1")
        self.attn = TFSamVisionAttention(config, window_size, name="attn")
        self.layer_norm2 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2")
        self.mlp = TFSamMLPBlock(config, name="mlp")
        self.window_size = window_size
        self.config = config

    def window_partition(self, hidden_states: tf.Tensor, window_size: int) -> Tuple[tf.Tensor, Tuple[int, int]]:
        batch_size, height, width, channel = shape_list(hidden_states)
        pad_h = (window_size - height % window_size) % window_size
        pad_w = (window_size - width % window_size) % window_size
        if pad_h > 0 or pad_w > 0:
            hidden_states = tf.pad(hidden_states, [[0, 0], [0, pad_h], [0, pad_w], [0, 0]])
        pad_height, pad_width = height + pad_h, width + pad_w
        hidden_states = tf.reshape(
            hidden_states,
            [batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel],
        )
        windows = tf.reshape(
            tf.transpose(hidden_states, perm=[0, 1, 3, 2, 4, 5]), [-1, window_size, window_size, channel]
        )
        return windows, (pad_height, pad_width)
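
    # Worked example (illustrative, default vision config): a 64x64 feature map with window_size=14 is padded
    # to 70x70, split into a 5x5 grid of 14x14 windows, and flattened to (batch * 25, 14, 14, channel);
    # window_unpartition below reverses the reshapes and crops the padding back off.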
  925. def window_unpartition(
  926. self, windows: tf.Tensor, window_size: int, padding_shape: Tuple[int, int], original_shape: Tuple[int, int]
  927. ) -> tf.Tensor:
  928. pad_height, pad_width = padding_shape
  929. height, width = original_shape
  930. batch_size = shape_list(windows)[0] // (pad_height * pad_width // window_size // window_size)
  931. hidden_states = tf.reshape(
  932. windows, [batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1]
  933. )
  934. hidden_states = tf.reshape(
  935. tf.transpose(hidden_states, perm=[0, 1, 3, 2, 4, 5]), [batch_size, pad_height, pad_width, -1]
  936. )
  937. if pad_height > height or pad_width > width:
  938. hidden_states = hidden_states[:, :height, :width, :]
  939. return hidden_states
  940. def call(
  941. self,
  942. hidden_states: tf.Tensor,
  943. output_attentions: Optional[bool] = False,
  944. training: Optional[bool] = False,
  945. ) -> Tuple[tf.Tensor]:
  946. residual = hidden_states
  947. hidden_states = self.layer_norm1(hidden_states)
  948. if self.window_size > 0:
  949. height, width = hidden_states.shape[1], hidden_states.shape[2]
  950. hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size)
  951. hidden_states, attn_weights = self.attn(
  952. hidden_states=hidden_states,
  953. output_attentions=output_attentions,
  954. training=training,
  955. )
  956. if self.window_size > 0:
  957. hidden_states = self.window_unpartition(hidden_states, self.window_size, padding_shape, (height, width))
  958. hidden_states = residual + hidden_states
  959. layernorm_output = self.layer_norm2(hidden_states)
  960. hidden_states = hidden_states + self.mlp(layernorm_output)
  961. outputs = (hidden_states,)
  962. if output_attentions:
  963. outputs += (attn_weights,)
  964. return outputs
  965. def build(self, input_shape=None):
  966. if self.built:
  967. return
  968. self.built = True
  969. if getattr(self, "layer_norm1", None) is not None:
  970. with tf.name_scope(self.layer_norm1.name):
  971. self.layer_norm1.build([None, None, None, self.config.hidden_size])
  972. if getattr(self, "attn", None) is not None:
  973. with tf.name_scope(self.attn.name):
  974. self.attn.build(None)
  975. if getattr(self, "layer_norm2", None) is not None:
  976. with tf.name_scope(self.layer_norm2.name):
  977. self.layer_norm2.build([None, None, None, self.config.hidden_size])
  978. if getattr(self, "mlp", None) is not None:
  979. with tf.name_scope(self.mlp.name):
  980. self.mlp.build(None)
  981. class TFSamVisionNeck(keras.layers.Layer):
  982. def __init__(self, config: SamVisionConfig, **kwargs):
  983. super().__init__(**kwargs)
  984. self.config = config
  985. self.conv1 = keras.layers.Conv2D(
  986. config.output_channels,
  987. kernel_size=1,
  988. use_bias=False,
  989. name="conv1",
  990. )
  991. self.layer_norm1 = TFSamLayerNorm(config.output_channels, name="layer_norm1")
  992. self.conv2 = keras.layers.Conv2D(
  993. config.output_channels,
  994. kernel_size=3,
  995. padding="same",
  996. use_bias=False,
  997. name="conv2",
  998. )
  999. self.layer_norm2 = TFSamLayerNorm(config.output_channels, name="layer_norm2")
  1000. def call(self, hidden_states):
  1001. hidden_states = self.conv1(hidden_states)
  1002. hidden_states = self.layer_norm1(hidden_states)
  1003. hidden_states = self.conv2(hidden_states)
  1004. hidden_states = self.layer_norm2(hidden_states)
  1005. hidden_states = tf.transpose(hidden_states, perm=[0, 3, 1, 2])
  1006. return hidden_states
    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "conv1", None) is not None:
            with tf.name_scope(self.conv1.name):
                self.conv1.build([None, None, None, self.config.hidden_size])
        if getattr(self, "layer_norm1", None) is not None:
            with tf.name_scope(self.layer_norm1.name):
                self.layer_norm1.build(None)
        if getattr(self, "conv2", None) is not None:
            with tf.name_scope(self.conv2.name):
                self.conv2.build([None, None, None, self.config.output_channels])
        if getattr(self, "layer_norm2", None) is not None:
            with tf.name_scope(self.layer_norm2.name):
                self.layer_norm2.build(None)
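

# Illustrative shape sketch (hypothetical helper, not used by the model): with the default
# `SamVisionConfig` values (hidden_size=768, output_channels=256, image_size=1024, patch_size=16),
# the neck maps the channels-last encoder features of shape (batch, 64, 64, 768) to a
# channels-first feature map of shape (batch, 256, 64, 64).
def _vision_neck_shape_example() -> tf.TensorShape:
    config = SamVisionConfig()
    neck = TFSamVisionNeck(config)
    grid_size = config.image_size // config.patch_size
    dummy = tf.zeros((1, grid_size, grid_size, config.hidden_size))
    return neck(dummy).shape  # expected: (1, 256, 64, 64)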


class TFSamVisionEncoder(keras.layers.Layer):
    def __init__(self, config: SamVisionConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.image_size = config.image_size

        self.patch_embed = TFSamPatchEmbeddings(config, name="patch_embed")

        self.pos_embed = None

        self.layers = []
        for i in range(config.num_hidden_layers):
            layer = TFSamVisionLayer(
                config,
                # Layers listed in `global_attn_indexes` attend globally (window_size=0);
                # all other layers use windowed attention.
                window_size=config.window_size if i not in config.global_attn_indexes else 0,
                name=f"layers_._{i}",
            )
            self.layers.append(layer)

        self.neck = TFSamVisionNeck(config, name="neck")
    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if self.config.use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            self.pos_embed = self.add_weight(
                shape=[
                    1,
                    self.config.image_size // self.config.patch_size,
                    self.config.image_size // self.config.patch_size,
                    self.config.hidden_size,
                ],
                initializer="zeros",
                trainable=True,
                name="pos_embed",
            )

        if getattr(self, "patch_embed", None) is not None:
            with tf.name_scope(self.patch_embed.name):
                self.patch_embed.build(None)
        if getattr(self, "neck", None) is not None:
            with tf.name_scope(self.neck.name):
                self.neck.build(None)
        for layer in self.layers:
            with tf.name_scope(layer.name):
                layer.build(None)
    def get_input_embeddings(self):
        return self.patch_embed

    def call(
        self,
        pixel_values: tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ) -> Union[Tuple, TFSamVisionEncoderOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.patch_embed(pixel_values)
        if self.pos_embed is not None:
            hidden_states = hidden_states + self.pos_embed

        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(hidden_states, output_attentions=output_attentions, training=training)
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        hidden_states = self.neck(hidden_states)

        if not return_dict:
            outputs = (hidden_states,)
            if output_hidden_states:
                outputs = outputs + (all_hidden_states,)
            if output_attentions:
                outputs = outputs + (all_self_attentions,)
            return outputs

        return TFSamVisionEncoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
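

# Illustrative smoke test (hypothetical helper, not used by the model): with the default
# `SamVisionConfig` shapes, the encoder maps channels-first `pixel_values` of shape
# (batch, 3, 1024, 1024) to a `last_hidden_state` of shape (batch, 256, 64, 64). A reduced
# `num_hidden_layers` keeps the sketch lightweight without changing the output shape.
def _vision_encoder_shape_example() -> tf.TensorShape:
    config = SamVisionConfig(num_hidden_layers=2)
    encoder = TFSamVisionEncoder(config)
    pixel_values = tf.zeros((1, config.num_channels, config.image_size, config.image_size))
    return encoder(pixel_values).last_hidden_state.shape  # expected: (1, 256, 64, 64)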


class TFSamPreTrainedModel(TFPreTrainedModel):
    config_class = SamConfig
    base_model_prefix = "sam"
    main_input_name = "pixel_values"


SAM_START_DOCSTRING = r"""
    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
    heads, etc.)

    This model is also a TensorFlow [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model)
    subclass. Use it as a regular TensorFlow Model and refer to the TensorFlow documentation for all matters related
    to general usage and behavior.

    Parameters:
        config ([`SamConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""

SAM_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for
            details.
        input_points (`tf.Tensor` of shape `(batch_size, num_points, 2)`):
            Input 2D spatial points used by the prompt encoder to encode the prompt. Generally yields much better
            results. The points can be obtained by passing a list of list of list to the processor, which will
            create corresponding `tf` tensors of dimension 4. The first dimension is the image batch size, the
            second dimension is the point batch size (i.e. how many segmentation masks we want the model to predict
            per input point), the third dimension is the number of points per segmentation mask (it is possible to
            pass multiple points for a single mask), and the last dimension is the x (horizontal) and y (vertical)
            coordinates of the point. If a different number of points is passed either for each image, or for each
            mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
            computation of the embedding will be skipped for these points using the labels.
        input_labels (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points)`):
            Input labels for the points, used by the prompt encoder to encode the prompt. According to the official
            implementation, there are 3 types of labels:

            - `1`: the point is a point that contains the object of interest
            - `0`: the point is a point that does not contain the object of interest
            - `-1`: the point corresponds to the background

            We added the label:

            - `-10`: the point is a padding point, thus should be ignored by the prompt encoder

            The padding labels should be automatically added by the processor.
        input_boxes (`tf.Tensor` of shape `(batch_size, num_boxes, 4)`):
            Input boxes for the points, used by the prompt encoder to encode the prompt. Generally yields much
            better generated masks. The boxes can be obtained by passing a list of list of list to the processor,
            which will generate a `tf` tensor, with each dimension corresponding respectively to the image batch
            size, the number of boxes per image and the coordinates of the top left and bottom right points of the
            box, in the order (`x1`, `y1`, `x2`, `y2`):

            - `x1`: the x coordinate of the top left point of the input box
            - `y1`: the y coordinate of the top left point of the input box
            - `x2`: the x coordinate of the bottom right point of the input box
            - `y2`: the y coordinate of the bottom right point of the input box

        input_masks (`tf.Tensor` of shape `(batch_size, image_size, image_size)`):
            SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
            generate a corresponding embedding, that will be fed later on to the mask decoder. These masks need to
            be manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`).
        image_embeddings (`tf.Tensor` of shape `(batch_size, output_channels, window_size, window_size)`):
            Image embeddings, used by the mask decoder to generate masks and IoU scores. For more memory-efficient
            computation, users can first retrieve the image embeddings using the `get_image_embeddings` method, and
            then feed them to the `call` method instead of feeding the `pixel_values`.
        multimask_output (`bool`, *optional*):
            In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
            bounding box if relevant). However, it is possible to just output a single mask, that corresponds to the
            "best" mask, by specifying `multimask_output=False`.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
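

# Minimal usage sketch (hypothetical helper, not part of the public API): it assumes the
# `SamProcessor` class, the `facebook/sam-vit-base` checkpoint and the `PIL`/`requests`
# packages, and only illustrates how the nested-list point prompt described in
# SAM_INPUTS_DOCSTRING above becomes a 4D `input_points` tensor.
def _example_point_prompt_segmentation():
    import requests
    from PIL import Image

    from transformers import SamProcessor, TFSamModel

    model = TFSamModel.from_pretrained("facebook/sam-vit-base")
    processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

    img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
    # One image, one mask per prompt, one point per mask -> input_points of shape (1, 1, 1, 2).
    input_points = [[[450, 600]]]

    inputs = processor(raw_image, input_points=input_points, return_tensors="tf")
    outputs = model(**inputs)

    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"], return_tensors="tf"
    )
    return masks, outputs.iou_scores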


@add_start_docstrings(
    "Segment Anything Model (SAM) for generating segmentation masks, given an input image and optional 2D"
    " locations and bounding boxes.",
    SAM_START_DOCSTRING,
)
class TFSamModel(TFSamPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"prompt_encoder.shared_embedding.positional_embedding"]

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.shared_image_embedding = TFSamPositionalEmbedding(config.vision_config, name="shared_image_embedding")

        self.vision_encoder = TFSamVisionEncoder(config.vision_config, name="vision_encoder")
        self.prompt_encoder = TFSamPromptEncoder(
            config.prompt_encoder_config, self.shared_image_embedding, name="prompt_encoder"
        )
        self.mask_decoder = TFSamMaskDecoder(config.mask_decoder_config, name="mask_decoder")
        self.config = config

    def get_input_embeddings(self):
        return self.vision_encoder.get_input_embeddings()

    def get_image_wide_positional_embeddings(self):
        size = self.config.prompt_encoder_config.image_embedding_size
        grid = tf.ones((size, size))
        # Normalized (x, y) coordinates of each pixel center in the image-embedding grid, in (0, 1).
        y_embed = tf.math.cumsum(grid, axis=0) - 0.5
        x_embed = tf.math.cumsum(grid, axis=1) - 0.5
        y_embed = y_embed / size
        x_embed = x_embed / size

        positional_embedding = self.shared_image_embedding(tf.stack([x_embed, y_embed], axis=-1))
        return tf.expand_dims(tf.transpose(positional_embedding, perm=[2, 0, 1]), axis=0)  # 1 x channel x height x width
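
    # Worked example (illustrative): for image_embedding_size = 2, cumsum(ones) - 0.5 over each
    # axis, divided by the grid size, gives normalized pixel-center coordinates:
    #     y_embed = [[0.25, 0.25], [0.75, 0.75]]
    #     x_embed = [[0.25, 0.75], [0.25, 0.75]]
    # These (x, y) pairs are then mapped through the shared random positional embedding.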

    def get_image_embeddings(
        self,
        pixel_values,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Returns the image embeddings by passing the pixel values through the vision encoder.

        Args:
            pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
                Input pixel values
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        vision_output = self.vision_encoder(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        image_embeddings = vision_output[0]
        return image_embeddings
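
    # Reuse sketch (illustrative, hypothetical usage): the image embeddings returned above can be
    # computed once and reused for several prompts, e.g.:
    #
    #     image_embeddings = model.get_image_embeddings(inputs["pixel_values"])
    #     inputs.pop("pixel_values")
    #     outputs = model(**inputs, image_embeddings=image_embeddings)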

    def get_prompt_embeddings(
        self,
        input_points: tf.Tensor | None = None,
        input_labels: tf.Tensor | None = None,
        input_boxes: tf.Tensor | None = None,
        input_masks: tf.Tensor | None = None,
    ):
        r"""
        Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt
        encoder.

        Args:
            input_points (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
                Optional input points for the prompt encoder. The padding of the points is automatically done by the
                processor. `point_batch_size` refers to the number of masks that we want the model to predict per
                point. The model will output `point_batch_size` times 3 masks in total.
            input_labels (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
                Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
                processor, or can be fed by the user.
            input_boxes (`tf.Tensor` of shape `(batch_size, num_boxes_per_image, 4)`):
                Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
                processor. Users can also pass the input boxes manually.
            input_masks (`tf.Tensor` of shape `(batch_size, image_size, image_size)`):
                Optional input masks for the prompt encoder.
        """
        prompt_output = self.prompt_encoder(
            input_points=input_points,
            input_labels=input_labels,
            input_boxes=input_boxes,
            input_masks=input_masks,
        )
        return prompt_output

    @unpack_inputs
    @add_start_docstrings_to_model_forward(SAM_INPUTS_DOCSTRING)
    def call(
        self,
        pixel_values: TFModelInputType | None = None,
        input_points: tf.Tensor | None = None,
        input_labels: tf.Tensor | None = None,
        input_boxes: tf.Tensor | None = None,
        input_masks: tf.Tensor | None = None,
        image_embeddings: tf.Tensor | None = None,
        multimask_output: bool = True,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        training: bool = False,
        **kwargs,
    ) -> TFSamImageSegmentationOutput | Tuple[tf.Tensor]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None and image_embeddings is None:
            raise ValueError("Either pixel_values or image_embeddings must be provided.")
        if pixel_values is not None and image_embeddings is not None:
            raise ValueError("Only one of pixel_values and image_embeddings can be provided.")

        if input_points is not None and len(input_points.shape) != 4:
            raise ValueError(
                "The input_points must be a 4D tensor of shape (`batch_size`, `point_batch_size`,"
                f" `nb_points_per_image`, `2`), got {input_points.shape}."
            )
        if input_boxes is not None and len(input_boxes.shape) != 3:
            raise ValueError(
                "The input_boxes must be a 3D tensor of shape (`batch_size`, `nb_boxes`, `4`),"
                f" got {input_boxes.shape}."
            )
        if input_points is not None and input_boxes is not None:
            point_batch_size = shape_list(input_points)[1]
            box_batch_size = shape_list(input_boxes)[1]
            if point_batch_size != box_batch_size:
                raise ValueError(
                    "You should provide as many bounding boxes as input points per box."
                    f" Got {point_batch_size} and {box_batch_size}."
                )
        if pixel_values is not None:
            # Ensures that later checks pass even with an all-None shape from the serving signature
            pixel_values = tf.ensure_shape(
                pixel_values,
                [
                    None,
                    self.config.vision_config.num_channels,
                    self.config.vision_config.image_size,
                    self.config.vision_config.image_size,
                ],
            )

        image_positional_embeddings = self.get_image_wide_positional_embeddings()
        # repeat with batch size
        batch_size = shape_list(pixel_values)[0] if pixel_values is not None else shape_list(image_embeddings)[0]
        image_positional_embeddings = tf.repeat(image_positional_embeddings, batch_size, axis=0)

        vision_attentions = None
        vision_hidden_states = None

        if pixel_values is not None:
            vision_outputs = self.vision_encoder(
                pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=True,
                training=training,
            )
            image_embeddings = vision_outputs["last_hidden_state"]

            if output_hidden_states:
                vision_hidden_states = vision_outputs["hidden_states"]
            if output_attentions:
                vision_attentions = vision_outputs["attentions"]

        if input_points is not None and input_labels is None:
            input_labels = tf.ones_like(input_points[:, :, :, 0], dtype=tf.int32)

        if input_points is not None and image_embeddings.shape[0] != input_points.shape[0]:
            raise ValueError(
                "The batch size of the image embeddings and the input points must be the same."
                f" Got {image_embeddings.shape[0]} and {input_points.shape[0]} respectively."
                " If you want to pass multiple points for the same image, make sure that you passed"
                " input_points of shape (batch_size, point_batch_size, num_points_per_image, 2) and"
                " input_labels of shape (batch_size, point_batch_size, num_points_per_image)."
            )

        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            batch_size=shape_list(image_embeddings)[0],
            input_points=input_points,
            input_labels=input_labels,
            input_boxes=input_boxes,
            input_masks=input_masks,
        )

        low_res_masks, iou_predictions, mask_decoder_attentions = self.mask_decoder(
            image_embeddings=image_embeddings,
            image_positional_embeddings=image_positional_embeddings,
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
            output_attentions=output_attentions,
        )

        if not return_dict:
            output = (iou_predictions, low_res_masks)
            if output_hidden_states:
                output = output + (vision_hidden_states,)
            if output_attentions:
                output = output + (vision_attentions, mask_decoder_attentions)
            return output

        return TFSamImageSegmentationOutput(
            iou_scores=iou_predictions,
            pred_masks=low_res_masks,
            vision_hidden_states=vision_hidden_states,
            vision_attentions=vision_attentions,
            mask_decoder_attentions=mask_decoder_attentions,
        )

    def serving_output(self, output: TFSamImageSegmentationOutput) -> TFSamImageSegmentationOutput:
        hs = tf.convert_to_tensor(output.vision_hidden_states) if self.config.output_hidden_states else None
        attns = tf.convert_to_tensor(output.vision_attentions) if self.config.output_attentions else None

        return TFSamImageSegmentationOutput(
            iou_scores=output.iou_scores,
            pred_masks=output.pred_masks,
            vision_hidden_states=hs if self.config.output_hidden_states else None,
            vision_attentions=attns if self.config.output_attentions else None,
            mask_decoder_attentions=output.mask_decoder_attentions if self.config.output_attentions else None,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "shared_image_embedding", None) is not None:
            with tf.name_scope(self.shared_image_embedding.name):
                self.shared_image_embedding.build(None)
        if getattr(self, "vision_encoder", None) is not None:
            with tf.name_scope(self.vision_encoder.name):
                self.vision_encoder.build(None)
        if getattr(self, "prompt_encoder", None) is not None:
            with tf.name_scope(self.prompt_encoder.name):
                self.prompt_encoder.build(None)
        if getattr(self, "mask_decoder", None) is not None:
            with tf.name_scope(self.mask_decoder.name):
                self.mask_decoder.build(None)