- # coding=utf-8
- # Copyright 2023 The Meta AI Authors and The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """
- TensorFlow SAM model. This file was mostly generated by auto-translation from the PyTorch original. In the event of a
- discrepancy, the original file should be regarded as the 'reference' version.
- """
- from __future__ import annotations
- import collections
- from dataclasses import dataclass
- from typing import Optional, Tuple, Union
- import numpy as np
- import tensorflow as tf
- from ...activations_tf import ACT2FN
- from ...modeling_tf_outputs import TFBaseModelOutput
- from ...modeling_tf_utils import TFModelInputType, TFPreTrainedModel, keras, shape_list, unpack_inputs
- from ...tf_utils import flatten, functional_layernorm
- from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
- from .configuration_sam import SamConfig, SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
- logger = logging.get_logger(__name__)
- _CONFIG_FOR_DOC = "SamConfig"
- _CHECKPOINT_FOR_DOC = "facebook/sam-vit-huge"
- @dataclass
- class TFSamVisionEncoderOutput(ModelOutput):
- """
- Base class for the SAM vision model's outputs that also contains image embeddings obtained by applying the projection
- layer to the pooler_output.
- Args:
- image_embeds (`tf.Tensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
- The image embeddings obtained by applying the projection layer to the pooler_output.
- last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
- Sequence of hidden-states at the output of the last layer of the model.
- hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
- the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
- Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
- attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
- Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
- sequence_length)`.
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
- heads.
- """
- image_embeds: tf.Tensor | None = None
- last_hidden_state: tf.Tensor = None
- hidden_states: Tuple[tf.Tensor, ...] | None = None
- attentions: Tuple[tf.Tensor, ...] | None = None
- @dataclass
- class TFSamImageSegmentationOutput(ModelOutput):
- """
- Base class for the Segment-Anything model's output.
- Args:
- iou_scores (`tf.Tensor` of shape `(batch_size, num_masks)`):
- The iou scores of the predicted masks.
- pred_masks (`tf.Tensor` of shape `(batch_size, num_masks, height, width)`):
- The predicted low-resolution masks. They need to be post-processed by the processor.
- vision_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
- the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
- Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs.
- vision_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
- Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
- sequence_length)`.
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
- heads.
- mask_decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
- Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
- sequence_length)`.
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
- heads.
- """
- iou_scores: tf.Tensor = None
- pred_masks: tf.Tensor = None
- vision_hidden_states: Tuple[tf.Tensor, ...] | None = None
- vision_attentions: Tuple[tf.Tensor, ...] | None = None
- mask_decoder_attentions: Tuple[tf.Tensor, ...] | None = None
- class TFSamPatchEmbeddings(keras.layers.Layer):
- """
- This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
- `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
- Transformer.
- """
- def __init__(self, config, **kwargs):
- super().__init__(**kwargs)
- image_size, patch_size = config.image_size, config.patch_size
- num_channels, hidden_size = config.num_channels, config.hidden_size
- image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
- patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
- num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
- self.image_size = image_size
- self.patch_size = patch_size
- self.num_channels = num_channels
- self.num_patches = num_patches
- self.projection = keras.layers.Conv2D(
- hidden_size, kernel_size=patch_size, strides=patch_size, name="projection"
- )
- def call(self, pixel_values):
- batch_size, num_channels, height, width = shape_list(pixel_values)
- if num_channels != self.num_channels:
- raise ValueError(
- "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
- )
- if height != self.image_size[0] or width != self.image_size[1]:
- raise ValueError(
- f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
- )
- embeddings = self.projection(tf.transpose(pixel_values, perm=[0, 2, 3, 1]))
- return embeddings
- def build(self, input_shape=None):
- if self.built:
- return
- self.built = True
- if getattr(self, "projection", None) is not None:
- with tf.name_scope(self.projection.name):
- self.projection.build([None, None, None, self.num_channels])
- class TFSamMLPBlock(keras.layers.Layer):
- def __init__(self, config, **kwargs):
- super().__init__(**kwargs)
- self.lin1 = keras.layers.Dense(config.mlp_dim, name="lin1")
- self.lin2 = keras.layers.Dense(config.hidden_size, name="lin2")
- self.act = ACT2FN[config.hidden_act]
- self.config = config
- def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
- hidden_states = self.lin1(hidden_states)
- hidden_states = self.act(hidden_states)
- hidden_states = self.lin2(hidden_states)
- return hidden_states
- def build(self, input_shape=None):
- if self.built:
- return
- self.built = True
- if getattr(self, "lin1", None) is not None:
- with tf.name_scope(self.lin1.name):
- self.lin1.build([None, None, self.config.hidden_size])
- if getattr(self, "lin2", None) is not None:
- with tf.name_scope(self.lin2.name):
- self.lin2.build([None, None, self.config.mlp_dim])
- class TFSamLayerNorm(keras.layers.Layer):
- r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
- The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
- width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
- """
- def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", **kwargs):
- super().__init__(**kwargs)
- self.eps = eps
- self.data_format = data_format
- self.normalized_shape = normalized_shape
- if self.data_format not in ["channels_last", "channels_first"]:
- raise NotImplementedError(f"Unsupported data format: {self.data_format}")
- def build(self, input_shape):
- self.weight = self.add_weight(shape=self.normalized_shape, initializer="ones", name="weight")
- self.bias = self.add_weight(shape=self.normalized_shape, initializer="zeros", name="bias")
- super().build(input_shape)
- def call(self, x: tf.Tensor) -> tf.Tensor:
- if self.data_format == "channels_last":
- x = functional_layernorm(x, weight=self.weight, bias=self.bias, epsilon=self.eps, axis=-1)
- elif self.data_format == "channels_first":
- x = functional_layernorm(x, weight=self.weight, bias=self.bias, epsilon=self.eps, axis=1)
- return x
- class TFSamAttention(keras.layers.Layer):
- """
- SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
- values.
- """
- def __init__(self, config, downsample_rate=None, **kwargs):
- super().__init__(**kwargs)
- self.hidden_size = config.hidden_size
- downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
- self.internal_dim = config.hidden_size // downsample_rate
- self.num_attention_heads = config.num_attention_heads
- if self.internal_dim % config.num_attention_heads != 0:
- raise ValueError("num_attention_heads must divide hidden_size.")
- self.q_proj = keras.layers.Dense(self.internal_dim, name="q_proj")
- self.k_proj = keras.layers.Dense(self.internal_dim, name="k_proj")
- self.v_proj = keras.layers.Dense(self.internal_dim, name="v_proj")
- self.out_proj = keras.layers.Dense(self.hidden_size, name="out_proj")
- def _separate_heads(self, hidden_states: tf.Tensor, num_attention_heads: int) -> tf.Tensor:
- batch, point_batch_size, n_tokens, channel = shape_list(hidden_states)
- c_per_head = channel // num_attention_heads
- hidden_states = tf.reshape(
- hidden_states, (batch * point_batch_size, n_tokens, num_attention_heads, c_per_head)
- )
- return tf.transpose(hidden_states, perm=[0, 2, 1, 3])
- def _recombine_heads(self, hidden_states: tf.Tensor, point_batch_size: int) -> tf.Tensor:
- batch, n_heads, n_tokens, c_per_head = shape_list(hidden_states)
- hidden_states = tf.transpose(hidden_states, perm=[0, 2, 1, 3])
- return tf.reshape(
- hidden_states,
- (batch // tf.reduce_max([1, point_batch_size]), point_batch_size, n_tokens, n_heads * c_per_head),
- )
- def call(self, query: tf.Tensor, key: tf.Tensor, value: tf.Tensor) -> tf.Tensor:
- # Input projections
- query = self.q_proj(query)
- key = self.k_proj(key)
- value = self.v_proj(value)
- point_batch_size = shape_list(query)[1]
- # Separate into heads
- query = self._separate_heads(query, self.num_attention_heads)
- key = self._separate_heads(key, self.num_attention_heads)
- value = self._separate_heads(value, self.num_attention_heads)
- # SamAttention
- _, _, _, c_per_head = shape_list(query)
- attn = tf.matmul(
- query, tf.transpose(key, perm=[0, 1, 3, 2])
- ) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens
- attn = attn / tf.math.sqrt(float(c_per_head))
- attn = tf.nn.softmax(attn, axis=-1)
- # Get output
- out = tf.matmul(attn, value)
- out = self._recombine_heads(out, point_batch_size)
- out = self.out_proj(out)
- return out
- def build(self, input_shape=None):
- if self.built:
- return
- self.built = True
- if getattr(self, "q_proj", None) is not None:
- with tf.name_scope(self.q_proj.name):
- self.q_proj.build([None, None, self.hidden_size])
- if getattr(self, "k_proj", None) is not None:
- with tf.name_scope(self.k_proj.name):
- self.k_proj.build([None, None, self.hidden_size])
- if getattr(self, "v_proj", None) is not None:
- with tf.name_scope(self.v_proj.name):
- self.v_proj.build([None, None, self.hidden_size])
- if getattr(self, "out_proj", None) is not None:
- with tf.name_scope(self.out_proj.name):
- self.out_proj.build([None, None, self.internal_dim])
- class TFSamTwoWayAttentionBlock(keras.layers.Layer):
- def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False, **kwargs):
- """
- A transformer block with four layers:
- (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on
- sparse inputs (4) cross attention of dense inputs -> sparse inputs
- Arguments:
- config (`SamMaskDecoderConfig`):
- The configuration file used to instantiate the block
- attention_downsample_rate (`int`, *optional*, defaults to 2):
- The downsample ratio of the block, used to reduce the inner dim of the attention.
- skip_first_layer_pe (`bool`, *optional*, defaults to `False`):
- Whether or not to skip the addition of the query_point_embedding on the first layer.
- """
- super().__init__(**kwargs)
- self.hidden_size = config.hidden_size
- self.layer_norm_eps = config.layer_norm_eps
- self.self_attn = TFSamAttention(config, downsample_rate=1, name="self_attn")
- self.layer_norm1 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm1")
- self.cross_attn_token_to_image = TFSamAttention(
- config, downsample_rate=attention_downsample_rate, name="cross_attn_token_to_image"
- )
- self.layer_norm2 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm2")
- self.mlp = TFSamMLPBlock(config, name="mlp")
- self.layer_norm3 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm3")
- self.layer_norm4 = keras.layers.LayerNormalization(epsilon=self.layer_norm_eps, name="layer_norm4")
- self.cross_attn_image_to_token = TFSamAttention(
- config, downsample_rate=attention_downsample_rate, name="cross_attn_image_to_token"
- )
- self.skip_first_layer_pe = skip_first_layer_pe
- def call(
- self,
- queries: tf.Tensor,
- keys: tf.Tensor,
- query_point_embedding: tf.Tensor,
- key_point_embedding: tf.Tensor,
- output_attentions: bool = False,
- ):
- # Self attention block
- if self.skip_first_layer_pe:
- queries = self.self_attn(query=queries, key=queries, value=queries)
- else:
- query = queries + query_point_embedding
- attn_out = self.self_attn(query=query, key=query, value=queries)
- queries = queries + attn_out
- queries = self.layer_norm1(queries)
- # Cross attention block, tokens attending to image embedding
- query = queries + query_point_embedding
- key = keys + key_point_embedding
- attn_out = self.cross_attn_token_to_image(query=query, key=key, value=keys)
- queries = queries + attn_out
- queries = self.layer_norm2(queries)
- # MLP block
- mlp_out = self.mlp(queries)
- queries = queries + mlp_out
- queries = self.layer_norm3(queries)
- # Cross attention block, image embedding attending to tokens
- query = queries + query_point_embedding
- key = keys + key_point_embedding
- attn_out = self.cross_attn_image_to_token(query=key, key=query, value=queries)
- keys = keys + attn_out
- keys = self.layer_norm4(keys)
- outputs = (queries, keys)
- if output_attentions:
- outputs = outputs + (attn_out,)
- else:
- outputs = outputs + (None,)
- return outputs
- def build(self, input_shape=None):
- if self.built:
- return
- self.built = True
- if getattr(self, "self_attn", None) is not None:
- with tf.name_scope(self.self_attn.name):
- self.self_attn.build(None)
- if getattr(self, "layer_norm1", None) is not None:
- with tf.name_scope(self.layer_norm1.name):
- self.layer_norm1.build([None, None, None, self.hidden_size])
- if getattr(self, "cross_attn_token_to_image", None) is not None:
- with tf.name_scope(self.cross_attn_token_to_image.name):
- self.cross_attn_token_to_image.build(None)
- if getattr(self, "layer_norm2", None) is not None:
- with tf.name_scope(self.layer_norm2.name):
- self.layer_norm2.build([None, None, None, self.hidden_size])
- if getattr(self, "mlp", None) is not None:
- with tf.name_scope(self.mlp.name):
- self.mlp.build(None)
- if getattr(self, "layer_norm3", None) is not None:
- with tf.name_scope(self.layer_norm3.name):
- self.layer_norm3.build([None, None, None, self.hidden_size])
- if getattr(self, "layer_norm4", None) is not None:
- with tf.name_scope(self.layer_norm4.name):
- self.layer_norm4.build([None, None, None, self.hidden_size])
- if getattr(self, "cross_attn_image_to_token", None) is not None:
- with tf.name_scope(self.cross_attn_image_to_token.name):
- self.cross_attn_image_to_token.build(None)
- class TFSamTwoWayTransformer(keras.layers.Layer):
- def __init__(self, config: SamMaskDecoderConfig, **kwargs):
- super().__init__(**kwargs)
- self.config = config
- self.num_hidden_layers = config.num_hidden_layers
- self.layers = []
- for i in range(self.num_hidden_layers):
- self.layers.append(TFSamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0), name=f"layers_._{i}"))
- self.final_attn_token_to_image = TFSamAttention(config, name="final_attn_token_to_image")
- self.layer_norm_final_attn = keras.layers.LayerNormalization(
- epsilon=config.layer_norm_eps, name="layer_norm_final_attn"
- )
- def call(
- self,
- point_embeddings: tf.Tensor,
- image_embeddings: tf.Tensor,
- image_positional_embeddings: tf.Tensor,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, TFBaseModelOutput]:
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- all_attentions = ()
- if image_embeddings is None:
- raise ValueError("You have to specify an image_embedding")
- image_embeddings = tf.transpose(flatten(image_embeddings, 2), perm=(0, 2, 1))[:, None]
- image_positional_embeddings = tf.transpose(flatten(image_positional_embeddings, 2), (0, 2, 1))[:, None]
- # Prepare queries
- queries = point_embeddings
- keys = image_embeddings
- # Apply transformer blocks and final layernorm
- for layer in self.layers:
- queries, keys, attention_outputs = layer(
- queries=queries,
- keys=keys,
- query_point_embedding=point_embeddings,
- key_point_embedding=image_positional_embeddings,
- output_attentions=output_attentions,
- )
- if output_attentions:
- all_attentions = all_attentions + (attention_outputs,)
- # Apply the final attention layer from the points to the image
- query = queries + point_embeddings
- key = keys + image_positional_embeddings
- attn_out = self.final_attn_token_to_image(query=query, key=key, value=keys)
- queries = queries + attn_out
- queries = self.layer_norm_final_attn(queries)
- return queries, keys, all_attentions
- def build(self, input_shape=None):
- if self.built:
- return
- self.built = True
- if getattr(self, "final_attn_token_to_image", None) is not None:
- with tf.name_scope(self.final_attn_token_to_image.name):
- self.final_attn_token_to_image.build(None)
- if getattr(self, "layer_norm_final_attn", None) is not None:
- with tf.name_scope(self.layer_norm_final_attn.name):
- self.layer_norm_final_attn.build([None, None, None, self.config.hidden_size])
- for layer in self.layers:
- with tf.name_scope(layer.name):
- layer.build(None)
- class TFSamFeedForward(keras.layers.Layer):
- def __init__(
- self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False, **kwargs
- ):
- super().__init__(**kwargs)
- self.num_layers = num_layers
- self.activation = keras.layers.ReLU()
- self.proj_in = keras.layers.Dense(hidden_dim, input_shape=(input_dim,), name="proj_in")
- self.proj_out = keras.layers.Dense(output_dim, input_shape=(hidden_dim,), name="proj_out")
- self.layers = [
- keras.layers.Dense(hidden_dim, input_shape=(hidden_dim,), name=f"layers_._{i}")
- for i in range(num_layers - 2)
- ]
- self.sigmoid_output = sigmoid_output
- self.hidden_dim = hidden_dim
- self.input_dim = input_dim
- def call(self, hidden_states):
- hidden_states = self.proj_in(hidden_states)
- hidden_states = self.activation(hidden_states)
- for layer in self.layers:
- hidden_states = self.activation(layer(hidden_states))
- hidden_states = self.proj_out(hidden_states)
- if self.sigmoid_output:
- hidden_states = tf.sigmoid(hidden_states)
- return hidden_states
- def build(self, input_shape=None):
- if self.built:
- return
- self.built = True
- if getattr(self, "proj_in", None) is not None:
- with tf.name_scope(self.proj_in.name):
- self.proj_in.build([None, None, self.input_dim])
- if getattr(self, "proj_out", None) is not None:
- with tf.name_scope(self.proj_out.name):
- self.proj_out.build([None, None, self.hidden_dim])
- if getattr(self, "layers", None) is not None:
- for layer in self.layers:
- with tf.name_scope(layer.name):
- layer.build([None, None, self.hidden_dim])
- class TFSamMaskDecoder(keras.layers.Layer):
- def __init__(self, config: SamMaskDecoderConfig, **kwargs):
- super().__init__(**kwargs)
- self.hidden_size = config.hidden_size
- self.num_multimask_outputs = config.num_multimask_outputs
- self.num_mask_tokens = config.num_multimask_outputs + 1
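- # One mask token per multimask output (3 by default) plus one extra token used for the single-mask output path.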
- self.transformer = TFSamTwoWayTransformer(config, name="transformer")
- self.upscale_conv1 = keras.layers.Conv2DTranspose(
- self.hidden_size // 4, kernel_size=2, strides=2, name="upscale_conv1", data_format="channels_first"
- )
- self.upscale_conv2 = keras.layers.Conv2DTranspose(
- self.hidden_size // 8, kernel_size=2, strides=2, name="upscale_conv2", data_format="channels_first"
- )
- self.upscale_layer_norm = TFSamLayerNorm(
- self.hidden_size // 4, data_format="channels_first", name="upscale_layer_norm"
- )
- self.activation = tf.nn.gelu
- mlps_list = []
- for i in range(self.num_mask_tokens):
- mlps_list += [
- TFSamFeedForward(
- self.hidden_size,
- self.hidden_size,
- self.hidden_size // 8,
- 3,
- name=f"output_hypernetworks_mlps_._{i}",
- )
- ]
- self.output_hypernetworks_mlps = mlps_list
- self.iou_prediction_head = TFSamFeedForward(
- self.hidden_size,
- config.iou_head_hidden_dim,
- self.num_mask_tokens,
- config.iou_head_depth,
- name="iou_prediction_head",
- )
- def build(self, input_shape=None):
- if self.built:
- return
- self.built = True
- self.iou_token = self.add_weight(shape=(1, self.hidden_size), name="iou_token.weight", trainable=True)
- self.mask_tokens = self.add_weight(
- shape=(self.num_mask_tokens, self.hidden_size), name="mask_tokens.weight", trainable=True
- )
- if getattr(self, "transformer", None) is not None:
- with tf.name_scope(self.transformer.name):
- self.transformer.build(None)
- if getattr(self, "upscale_conv1", None) is not None:
- with tf.name_scope(self.upscale_conv1.name):
- self.upscale_conv1.build([None, self.hidden_size, None, None])
- if getattr(self, "upscale_conv2", None) is not None:
- with tf.name_scope(self.upscale_conv2.name):
- self.upscale_conv2.build([None, self.hidden_size // 4, None, None])
- if getattr(self, "upscale_layer_norm", None) is not None:
- with tf.name_scope(self.upscale_layer_norm.name):
- self.upscale_layer_norm.build(None)
- if getattr(self, "iou_prediction_head", None) is not None:
- with tf.name_scope(self.iou_prediction_head.name):
- self.iou_prediction_head.build(None)
- for mlp in self.output_hypernetworks_mlps:
- with tf.name_scope(mlp.name):
- mlp.build(None)
- def call(
- self,
- image_embeddings: tf.Tensor,
- image_positional_embeddings: tf.Tensor,
- sparse_prompt_embeddings: tf.Tensor,
- dense_prompt_embeddings: tf.Tensor,
- multimask_output: bool,
- output_attentions: Optional[bool] = None,
- ) -> Tuple[tf.Tensor, tf.Tensor]:
- batch_size, num_channels, height, width = shape_list(image_embeddings)
- point_batch_size = tf.math.maximum(1, tf.shape(sparse_prompt_embeddings)[1])
- output_tokens = tf.concat([self.iou_token, self.mask_tokens], axis=0)  # (1 + num_mask_tokens, hidden_size)
- output_tokens = tf.tile(
- output_tokens[None, None, :], [batch_size, point_batch_size, 1, 1]
- )  # (batch_size, point_batch_size, 1 + num_mask_tokens, hidden_size)
- # Matt: The original Torch code checked that the sum of sparse_prompt_embeddings equalled 0. However, this only
- # happens when the sparse prompt embeddings are an empty tensor with shape[1] == 0. I replaced
- # it with an explicit shape check to avoid data-dependent control flow which breaks XLA.
- if shape_list(sparse_prompt_embeddings)[1] != 0:
- tokens = tf.concat((output_tokens, sparse_prompt_embeddings), axis=2)
- else:
- tokens = output_tokens
- point_embeddings = tf.cast(tokens, self.iou_token.dtype)
- image_embeddings = image_embeddings + dense_prompt_embeddings
- image_embeddings = tf.repeat(image_embeddings, point_batch_size, axis=0)
- image_positional_embeddings = tf.repeat(image_positional_embeddings, point_batch_size, axis=0)
- point_embedding, image_embeddings, attentions = self.transformer(
- point_embeddings=point_embeddings,
- image_embeddings=image_embeddings,
- image_positional_embeddings=image_positional_embeddings,
- output_attentions=output_attentions,
- )
- iou_token_out = point_embedding[:, :, 0, :]
- mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :]
- image_embeddings = tf.transpose(image_embeddings, perm=(0, 1, 3, 2))
- image_embeddings = tf.reshape(image_embeddings, [batch_size * point_batch_size, num_channels, height, width])
- upscaled_embedding = self.upscale_conv1(image_embeddings)
- upscaled_embedding = self.activation(self.upscale_layer_norm(upscaled_embedding))
- upscaled_embedding = self.activation(self.upscale_conv2(upscaled_embedding))
- hyper_in_list = []
- for i in range(self.num_mask_tokens):
- current_mlp = self.output_hypernetworks_mlps[i]
- hyper_in_list += [current_mlp(mask_tokens_out[:, :, i, :])]
- hyper_in = tf.stack(hyper_in_list, axis=2)
- _, num_channels, height, width = shape_list(upscaled_embedding)
- upscaled_embedding = tf.reshape(
- upscaled_embedding, [batch_size, point_batch_size, num_channels, height * width]
- )
- masks = tf.reshape(hyper_in @ upscaled_embedding, [batch_size, point_batch_size, -1, height, width])
- iou_pred = self.iou_prediction_head(iou_token_out)
- if multimask_output:
- mask_slice = slice(1, None)
- else:
- mask_slice = slice(0, 1)
- masks = masks[:, :, mask_slice, :, :]
- iou_pred = iou_pred[:, :, mask_slice]
- outputs = (masks, iou_pred)
- if output_attentions:
- outputs = outputs + (attentions,)
- else:
- outputs = outputs + (None,)
- return outputs
- class TFSamPositionalEmbedding(keras.layers.Layer):
- def __init__(self, config, **kwargs):
- super().__init__(**kwargs)
- self.scale = config.hidden_size // 2
- self.config = config
- def build(self, input_shape):
- # TODO Matt: What is going on here? Why is a non-trainable weight randomly initialized?
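- # (The original SAM uses a fixed random Gaussian matrix for its Fourier-feature positional encoding; the matrix
- # is drawn once at initialization and never updated, hence the non-trainable weight here.)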
- self.positional_embedding = self.add_weight(
- name="positional_embedding",
- shape=(2, self.config.num_pos_feats),
- initializer=keras.initializers.RandomNormal(mean=0.0, stddev=self.scale),
- trainable=False,
- )
- super().build(input_shape)
- def call(self, input_coords, input_shape=None):
- """Positionally encode points that are normalized to [0,1]."""
- coordinates = tf.identity(input_coords)
- if input_shape is not None:
- coordinates = tf.stack(
- [
- tf.cast(coordinates[:, :, :, 0], tf.float32) / input_shape[1],
- tf.cast(coordinates[:, :, :, 1], tf.float32) / input_shape[0],
- ],
- axis=-1,
- )
- # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
- coordinates = 2 * coordinates - 1
- coordinates = tf.cast(coordinates, self.positional_embedding.dtype)
- coordinates = tf.matmul(coordinates, self.positional_embedding)
- coordinates = 2 * np.pi * coordinates
- # outputs d_1 x ... x d_n x channel shape
- return tf.concat([tf.sin(coordinates), tf.cos(coordinates)], axis=-1)
- class TFSamMaskEmbedding(keras.layers.Layer):
- def __init__(self, config: SamPromptEncoderConfig, **kwargs):
- super().__init__(**kwargs)
- self.mask_input_channels = config.mask_input_channels // 4
- self.activation = ACT2FN[config.hidden_act]
- self.conv1 = keras.layers.Conv2D(self.mask_input_channels, kernel_size=2, strides=2, name="conv1")
- self.conv2 = keras.layers.Conv2D(config.mask_input_channels, kernel_size=2, strides=2, name="conv2")
- self.conv3 = keras.layers.Conv2D(config.hidden_size, kernel_size=1, name="conv3")
- self.layer_norm1 = TFSamLayerNorm(self.mask_input_channels, config.layer_norm_eps, name="layer_norm1")
- self.layer_norm2 = TFSamLayerNorm(self.mask_input_channels * 4, config.layer_norm_eps, name="layer_norm2")
- self.config = config
- def call(self, masks):
- masks = tf.transpose(masks, perm=(0, 2, 3, 1)) # Convert to channels-last
- hidden_states = self.conv1(masks)
- hidden_states = self.layer_norm1(hidden_states)
- hidden_states = self.activation(hidden_states)
- hidden_states = self.conv2(hidden_states)
- hidden_states = self.layer_norm2(hidden_states)
- hidden_states = self.activation(hidden_states)
- dense_embeddings = self.conv3(hidden_states)
- dense_embeddings = tf.transpose(dense_embeddings, perm=(0, 3, 1, 2)) # Convert back to channels-first
- return dense_embeddings
- def build(self, input_shape=None):
- # This class needs an explicit build method because it isn't called with the standard dummy inputs
- if self.built:
- return
- self.built = True
- with tf.name_scope("conv1"):
- self.conv1.build([None, None, None, 1])
- with tf.name_scope("conv2"):
- self.conv2.build([None, None, None, self.mask_input_channels])
- with tf.name_scope("conv3"):
- self.conv3.build([None, None, None, self.mask_input_channels * 4])
- with tf.name_scope("layer_norm1"):
- self.layer_norm1.build([None, None, None, self.mask_input_channels])
- with tf.name_scope("layer_norm2"):
- self.layer_norm2.build([None, None, None, self.mask_input_channels * 4])
- class TFSamPromptEncoder(keras.layers.Layer):
- def __init__(self, config: SamPromptEncoderConfig, shared_patch_embedding, **kwargs):
- super().__init__(**kwargs)
- self.shared_embedding = shared_patch_embedding
- self.mask_embed = TFSamMaskEmbedding(config, name="mask_embed")
- self.no_mask_embed = None
- self.image_embedding_size = (config.image_embedding_size, config.image_embedding_size)
- self.input_image_size = config.image_size
- self.point_embed = []
- self.hidden_size = config.hidden_size
- self.not_a_point_embed = None
- self.config = config
- def build(self, input_shape=None):
- self.no_mask_embed = self.add_weight(
- name="no_mask_embed.weight",
- shape=(1, self.hidden_size),
- initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
- trainable=True,
- )
- self.point_embed = [
- self.add_weight(
- name=f"point_embed_._{i}.weight",
- shape=(1, self.hidden_size),
- initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
- trainable=True,
- )
- for i in range(self.config.num_point_embeddings)
- ]
- self.not_a_point_embed = self.add_weight(
- name="not_a_point_embed.weight",
- shape=(1, self.hidden_size),
- initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
- trainable=True,
- )
- with tf.name_scope("mask_embed"):
- # We must explicitly build the mask embed because it isn't touched by the standard dummy inputs
- self.mask_embed.build(
- (None, self.config.mask_input_channels, self.config.image_size, self.config.image_size)
- )
- if self.built:
- return
- self.built = True
- if getattr(self, "mask_embed", None) is not None:
- with tf.name_scope(self.mask_embed.name):
- self.mask_embed.build(None)
- def _embed_points(self, points: tf.Tensor, labels: tf.Tensor, pad: bool) -> tf.Tensor:
- """Embeds point prompts."""
- points = points + 0.5 # Shift to center of pixel
- if pad:
- target_point_shape = (shape_list(points)[0], shape_list(points)[1], 1, shape_list(points)[-1])
- target_labels_shape = (shape_list(points)[0], shape_list(points)[1], 1)
- padding_point = tf.zeros(target_point_shape, dtype=points.dtype)
- padding_label = -tf.ones(target_labels_shape, dtype=labels.dtype)
- points = tf.concat([points, padding_point], axis=2)
- labels = tf.concat([labels, padding_label], axis=2)
- input_shape = (self.input_image_size, self.input_image_size)
- point_embedding = self.shared_embedding(points, input_shape)
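- # Map each label to its embedding: -1 (background / the pad point appended above) -> not_a_point_embed,
- # -10 (processor padding) -> zeroed out, 0 (negative point) -> + point_embed[0], 1 (positive point) -> + point_embed[1].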
- point_embedding = tf.where(labels[..., None] == -1, self.not_a_point_embed[0], point_embedding)
- point_embedding = tf.where(
- labels[..., None] != -10,
- point_embedding,
- tf.zeros_like(point_embedding),
- )
- point_embedding = tf.where(
- (labels == 0)[:, :, :, None], point_embedding + self.point_embed[0], point_embedding
- )
- point_embedding = tf.where(
- (labels == 1)[:, :, :, None], point_embedding + self.point_embed[1], point_embedding
- )
- return point_embedding
- def _embed_boxes(self, boxes: tf.Tensor) -> tf.Tensor:
- """Embeds box prompts."""
- boxes = boxes + 0.5 # Shift to center of pixel
- batch_size, nb_boxes = shape_list(boxes)[:2]
- coords = tf.reshape(boxes, (batch_size, nb_boxes, 2, 2))
- input_shape = (self.input_image_size, self.input_image_size)
- corner_embedding = self.shared_embedding(coords, input_shape)
- corner_embedding += tf.where(
- tf.range(shape_list(corner_embedding)[2])[None, None, :, None] == 0,
- self.point_embed[2][0],
- self.point_embed[3][0],
- )
- return corner_embedding
- def call(
- self,
- batch_size: Optional[int],
- input_points: Optional[Tuple[tf.Tensor, tf.Tensor]],
- input_labels: tf.Tensor | None,
- input_boxes: tf.Tensor | None,
- input_masks: tf.Tensor | None,
- ) -> Tuple[tf.Tensor, tf.Tensor]:
- """
- Embeds different types of prompts, returning both sparse and dense embeddings.
- Args:
- input_points (`tf.Tensor`, *optional*):
- point coordinates to embed (their labels are passed via `input_labels`).
- input_boxes (`tf.Tensor`, *optional*):
- boxes to embed.
- input_masks (`tf.Tensor`, *optional*):
- masks to embed.
- """
- sparse_embeddings = None
- if input_points is not None:
- batch_size, point_batch_size = shape_list(input_points)[:2]
- if input_labels is None:
- raise ValueError("If points are provided, labels must also be provided.")
- point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
- sparse_embeddings = tf.zeros(
- (batch_size, point_batch_size, 0, self.hidden_size), dtype=point_embeddings.dtype
- )
- sparse_embeddings = tf.concat([sparse_embeddings, point_embeddings], axis=2)
- if input_boxes is not None:
- batch_size = shape_list(input_boxes)[0]
- box_embeddings = self._embed_boxes(input_boxes)
- if sparse_embeddings is None:
- sparse_embeddings = box_embeddings
- else:
- sparse_embeddings = tf.concat([sparse_embeddings, box_embeddings], axis=2)
- if input_masks is not None:
- dense_embeddings = self.mask_embed(input_masks)
- else:
- dense_embeddings = self.no_mask_embed[0]
- dense_embeddings = tf.reshape(dense_embeddings, (1, -1, 1, 1))
- dense_embeddings = tf.tile(
- dense_embeddings, (batch_size, 1, self.image_embedding_size[0], self.image_embedding_size[1])
- )
- if sparse_embeddings is None:
- sparse_embeddings = tf.zeros((batch_size, 0, 1, self.hidden_size), dtype=dense_embeddings.dtype)
- return sparse_embeddings, dense_embeddings
- class TFSamVisionAttention(keras.layers.Layer):
- """Multi-head Attention block with relative position embeddings."""
- def __init__(self, config, window_size, **kwargs):
- super().__init__(**kwargs)
- input_size = (
- (config.image_size // config.patch_size, config.image_size // config.patch_size)
- if window_size == 0
- else (window_size, window_size)
- )
- self.input_size = input_size
- self.num_attention_heads = config.num_attention_heads
- head_dim = config.hidden_size // config.num_attention_heads
- self.head_dim = head_dim
- self.scale = head_dim**-0.5
- self.dropout = config.attention_dropout
- self.qkv = keras.layers.Dense(config.hidden_size * 3, use_bias=config.qkv_bias, name="qkv")
- self.proj = keras.layers.Dense(config.hidden_size, name="proj")
- self.use_rel_pos = config.use_rel_pos
- if self.use_rel_pos:
- if input_size is None:
- raise ValueError("Input size must be provided if using relative positional encoding.")
- self.config = config
- def build(self, input_shape=None):
- if self.input_size is not None:
- # initialize relative positional embeddings
- self.rel_pos_h = self.add_weight(
- shape=(2 * self.input_size[0] - 1, self.head_dim), initializer="zeros", name="rel_pos_h"
- )
- self.rel_pos_w = self.add_weight(
- shape=(2 * self.input_size[1] - 1, self.head_dim), initializer="zeros", name="rel_pos_w"
- )
- if self.built:
- return
- self.built = True
- if getattr(self, "qkv", None) is not None:
- with tf.name_scope(self.qkv.name):
- self.qkv.build([None, None, self.config.hidden_size])
- if getattr(self, "proj", None) is not None:
- with tf.name_scope(self.proj.name):
- self.proj.build([None, None, self.config.hidden_size])
- def get_rel_pos(self, q_size: int, k_size: int, rel_pos: tf.Tensor) -> tf.Tensor:
- """
- Get relative positional embeddings according to the relative positions of
- query and key sizes.
- Args:
- q_size (int):
- size of the query.
- k_size (int):
- size of key k.
- rel_pos (`tf.Tensor`):
- relative position embeddings (L, channel).
- Returns:
- Extracted positional embeddings according to relative positions.
- """
- max_rel_dist = int(2 * max(q_size, k_size) - 1)
- # Interpolate rel pos if needed.
- if rel_pos.shape[0] != max_rel_dist:
- # Interpolate rel pos.
- rel_pos_resized = tf.image.resize(
- tf.reshape(rel_pos, (1, rel_pos.shape[0], -1)),
- size=(max_rel_dist, rel_pos.shape[1]),
- method="bilinear",
- )
- rel_pos_resized = tf.reshape(rel_pos_resized, (-1, max_rel_dist))
- else:
- rel_pos_resized = rel_pos
- # Scale the coords with short length if shapes for q and k are different.
- q_coords = tf.expand_dims(tf.range(q_size, dtype=tf.float32), 1) * max(k_size / q_size, 1.0)
- k_coords = tf.expand_dims(tf.range(k_size, dtype=tf.float32), 0) * max(q_size / k_size, 1.0)
- relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
- return tf.gather(rel_pos_resized, tf.cast(relative_coords, tf.int32))
- def add_decomposed_rel_pos(
- self,
- attn: tf.Tensor,
- query: tf.Tensor,
- rel_pos_h: tf.Tensor,
- rel_pos_w: tf.Tensor,
- q_size: Tuple[int, int],
- k_size: Tuple[int, int],
- ) -> tf.Tensor:
- """
- Calculate decomposed Relative Positional Embeddings as introduced in MViTv2:
- https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
- Args:
- attn (`tf.Tensor`):
- attention map.
- query (`tf.Tensor`):
- query q in the attention layer with shape (batch_size, query_height * query_width, channel).
- rel_pos_h (`tf.Tensor`):
- relative position embeddings (Lh, channel) for height axis.
- rel_pos_w (`tf.Tensor`):
- relative position embeddings (Lw, channel) for width axis.
- q_size (tuple):
- spatial sequence size of query q with (query_height, query_width).
- k_size (tuple):
- spatial sequence size of key k with (key_height, key_width).
- Returns:
- attn (`tf.Tensor`):
- attention map with added relative positional embeddings.
- """
- query_height, query_width = q_size
- key_height, key_width = k_size
- relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
- relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)
- batch_size, _, dim = shape_list(query)
- reshaped_query = tf.reshape(query, (batch_size, query_height, query_width, dim))
- rel_h = tf.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
- rel_w = tf.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
- attn = tf.reshape(attn, (batch_size, query_height, query_width, key_height, key_width))
- attn = attn + tf.expand_dims(rel_h, axis=-1) + tf.expand_dims(rel_w, axis=-2)
- attn = tf.reshape(attn, (batch_size, query_height * query_width, key_height * key_width))
- return attn
- def call(self, hidden_states: tf.Tensor, output_attentions=False, training=False) -> tf.Tensor:
- batch_size, height, width, _ = shape_list(hidden_states)
- # qkv with shape (3, batch_size, nHead, height * width, channel)
- qkv = tf.reshape(self.qkv(hidden_states), (batch_size, height * width, 3, self.num_attention_heads, -1))
- qkv = tf.transpose(qkv, perm=(2, 0, 3, 1, 4))
- # q, k, v with shape (batch_size * nHead, height * width, channel)
- query, key, value = tf.unstack(
- tf.reshape(qkv, (3, batch_size * self.num_attention_heads, height * width, -1)), axis=0
- )
- attn_weights = tf.matmul(query * self.scale, key, transpose_b=True)
- if self.use_rel_pos:
- attn_weights = self.add_decomposed_rel_pos(
- attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
- )
- attn_weights = tf.nn.softmax(attn_weights, axis=-1)
- if training:
- attn_probs = tf.nn.dropout(attn_weights, rate=self.dropout)
- else:
- attn_probs = attn_weights
- attn_output = tf.reshape(attn_probs @ value, (batch_size, self.num_attention_heads, height, width, -1))
- attn_output = tf.transpose(attn_output, perm=(0, 2, 3, 1, 4))
- attn_output = tf.reshape(attn_output, (batch_size, height, width, self.config.hidden_size))
- attn_output = self.proj(attn_output)
- if output_attentions:
- outputs = (attn_output, attn_weights)
- else:
- outputs = (attn_output, None)
- return outputs
- class TFSamVisionLayer(keras.layers.Layer):
- def __init__(self, config, window_size, **kwargs):
- super().__init__(**kwargs)
- self.layer_norm1 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm1")
- self.attn = TFSamVisionAttention(config, window_size, name="attn")
- self.layer_norm2 = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm2")
- self.mlp = TFSamMLPBlock(config, name="mlp")
- self.window_size = window_size
- self.config = config
- def window_partition(self, hidden_states: tf.Tensor, window_size: int) -> Tuple[tf.Tensor, Tuple[int, int]]:
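- # Split the (batch, height, width, channel) feature map into non-overlapping window_size x window_size windows,
- # padding height and width up to multiples of window_size first; returns the windows and the padded spatial shape.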
- batch_size, height, width, channel = shape_list(hidden_states)
- pad_h = (window_size - height % window_size) % window_size
- pad_w = (window_size - width % window_size) % window_size
- if pad_h > 0 or pad_w > 0:
- hidden_states = tf.pad(hidden_states, [[0, 0], [0, pad_h], [0, pad_w], [0, 0]])
- pad_height, pad_width = height + pad_h, width + pad_w
- hidden_states = tf.reshape(
- hidden_states,
- [batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel],
- )
- windows = tf.reshape(
- tf.transpose(hidden_states, perm=[0, 1, 3, 2, 4, 5]), [-1, window_size, window_size, channel]
- )
- return windows, (pad_height, pad_width)
- def window_unpartition(
- self, windows: tf.Tensor, window_size: int, padding_shape: Tuple[int, int], original_shape: Tuple[int, int]
- ) -> tf.Tensor:
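- # Reverse of window_partition: reassemble the windows into a (batch, pad_height, pad_width, channel) map and
- # crop away any padding to recover the original (height, width) resolution.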
- pad_height, pad_width = padding_shape
- height, width = original_shape
- batch_size = shape_list(windows)[0] // (pad_height * pad_width // window_size // window_size)
- hidden_states = tf.reshape(
- windows, [batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1]
- )
- hidden_states = tf.reshape(
- tf.transpose(hidden_states, perm=[0, 1, 3, 2, 4, 5]), [batch_size, pad_height, pad_width, -1]
- )
- if pad_height > height or pad_width > width:
- hidden_states = hidden_states[:, :height, :width, :]
- return hidden_states
- def call(
- self,
- hidden_states: tf.Tensor,
- output_attentions: Optional[bool] = False,
- training: Optional[bool] = False,
- ) -> Tuple[tf.Tensor]:
- residual = hidden_states
- hidden_states = self.layer_norm1(hidden_states)
- if self.window_size > 0:
- height, width = hidden_states.shape[1], hidden_states.shape[2]
- hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size)
- hidden_states, attn_weights = self.attn(
- hidden_states=hidden_states,
- output_attentions=output_attentions,
- training=training,
- )
- if self.window_size > 0:
- hidden_states = self.window_unpartition(hidden_states, self.window_size, padding_shape, (height, width))
- hidden_states = residual + hidden_states
- layernorm_output = self.layer_norm2(hidden_states)
- hidden_states = hidden_states + self.mlp(layernorm_output)
- outputs = (hidden_states,)
- if output_attentions:
- outputs += (attn_weights,)
- return outputs
- def build(self, input_shape=None):
- if self.built:
- return
- self.built = True
- if getattr(self, "layer_norm1", None) is not None:
- with tf.name_scope(self.layer_norm1.name):
- self.layer_norm1.build([None, None, None, self.config.hidden_size])
- if getattr(self, "attn", None) is not None:
- with tf.name_scope(self.attn.name):
- self.attn.build(None)
- if getattr(self, "layer_norm2", None) is not None:
- with tf.name_scope(self.layer_norm2.name):
- self.layer_norm2.build([None, None, None, self.config.hidden_size])
- if getattr(self, "mlp", None) is not None:
- with tf.name_scope(self.mlp.name):
- self.mlp.build(None)
- class TFSamVisionNeck(keras.layers.Layer):
- def __init__(self, config: SamVisionConfig, **kwargs):
- super().__init__(**kwargs)
- self.config = config
- self.conv1 = keras.layers.Conv2D(
- config.output_channels,
- kernel_size=1,
- use_bias=False,
- name="conv1",
- )
- self.layer_norm1 = TFSamLayerNorm(config.output_channels, name="layer_norm1")
- self.conv2 = keras.layers.Conv2D(
- config.output_channels,
- kernel_size=3,
- padding="same",
- use_bias=False,
- name="conv2",
- )
- self.layer_norm2 = TFSamLayerNorm(config.output_channels, name="layer_norm2")
- def call(self, hidden_states):
- hidden_states = self.conv1(hidden_states)
- hidden_states = self.layer_norm1(hidden_states)
- hidden_states = self.conv2(hidden_states)
- hidden_states = self.layer_norm2(hidden_states)
- hidden_states = tf.transpose(hidden_states, perm=[0, 3, 1, 2])
- return hidden_states
- def build(self, input_shape=None):
- if self.built:
- return
- self.built = True
- if getattr(self, "conv1", None) is not None:
- with tf.name_scope(self.conv1.name):
- self.conv1.build([None, None, None, self.config.hidden_size])
- if getattr(self, "layer_norm1", None) is not None:
- with tf.name_scope(self.layer_norm1.name):
- self.layer_norm1.build(None)
- if getattr(self, "conv2", None) is not None:
- with tf.name_scope(self.conv2.name):
- self.conv2.build([None, None, None, self.config.output_channels])
- if getattr(self, "layer_norm2", None) is not None:
- with tf.name_scope(self.layer_norm2.name):
- self.layer_norm2.build(None)
- class TFSamVisionEncoder(keras.layers.Layer):
- def __init__(self, config: SamVisionConfig, **kwargs):
- super().__init__(**kwargs)
- self.config = config
- self.image_size = config.image_size
- self.patch_embed = TFSamPatchEmbeddings(config, name="patch_embed")
- self.pos_embed = None
- self.layers = []
- for i in range(config.num_hidden_layers):
- layer = TFSamVisionLayer(
- config,
- window_size=config.window_size if i not in config.global_attn_indexes else 0,
- name=f"layers_._{i}",
- )
- self.layers.append(layer)
- self.neck = TFSamVisionNeck(config, name="neck")
- def build(self, input_shape=None):
- if self.built:
- return
- self.built = True
- if self.config.use_abs_pos:
- # Initialize absolute positional embedding with pretrain image size.
- self.pos_embed = self.add_weight(
- shape=[
- 1,
- self.config.image_size // self.config.patch_size,
- self.config.image_size // self.config.patch_size,
- self.config.hidden_size,
- ],
- initializer="zeros",
- trainable=True,
- name="pos_embed",
- )
- if getattr(self, "patch_embed", None) is not None:
- with tf.name_scope(self.patch_embed.name):
- self.patch_embed.build(None)
- if getattr(self, "neck", None) is not None:
- with tf.name_scope(self.neck.name):
- self.neck.build(None)
- for layer in self.layers:
- with tf.name_scope(layer.name):
- layer.build(None)
- def get_input_embeddings(self):
- return self.patch_embed
- def call(
- self,
- pixel_values: tf.Tensor | None = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- training: Optional[bool] = False,
- ) -> Union[Tuple, TFSamVisionEncoderOutput]:
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- if pixel_values is None:
- raise ValueError("You have to specify pixel_values")
- hidden_states = self.patch_embed(pixel_values)
- if self.pos_embed is not None:
- hidden_states = hidden_states + self.pos_embed
- all_hidden_states = () if output_hidden_states else None
- all_self_attentions = () if output_attentions else None
- for i, layer_module in enumerate(self.layers):
- if output_hidden_states:
- all_hidden_states = all_hidden_states + (hidden_states,)
- layer_outputs = layer_module(hidden_states, output_attentions=output_attentions, training=training)
- hidden_states = layer_outputs[0]
- if output_attentions:
- all_self_attentions = all_self_attentions + (layer_outputs[1],)
- if output_hidden_states:
- all_hidden_states = all_hidden_states + (hidden_states,)
- hidden_states = self.neck(hidden_states)
- if not return_dict:
- outputs = (hidden_states,)
- if output_hidden_states:
- outputs = outputs + (all_hidden_states,)
- if output_attentions:
- outputs = outputs + (all_self_attentions,)
- return outputs
- return TFSamVisionEncoderOutput(
- last_hidden_state=hidden_states,
- hidden_states=all_hidden_states,
- attentions=all_self_attentions,
- )
- class TFSamPreTrainedModel(TFPreTrainedModel):
- config_class = SamConfig
- base_model_prefix = "sam"
- main_input_name = "pixel_values"
- SAM_START_DOCSTRING = r"""
- This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
- library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
- etc.)
- This model is also a TensorFlow [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model)
- subclass. Use it as a regular TensorFlow Model and refer to the TensorFlow documentation for all matter related to
- general usage and behavior.
- Parameters:
- config ([`SamConfig`]): Model configuration class with all the parameters of the model.
- Initializing with a config file does not load the weights associated with the model, only the
- configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
- """
- SAM_INPUTS_DOCSTRING = r"""
- Args:
- pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
- Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for
- details.
- input_points (`tf.Tensor` of shape `(batch_size, num_points, 2)`):
- Input 2D spatial points, used by the prompt encoder to encode the prompt. Generally yields much
- better results. The points can be obtained by passing a list of list of list to the processor that will
- create corresponding `tf` tensors of dimension 4. The first dimension is the image batch size, the second
- dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict per
- input point), the third dimension is the number of points per segmentation mask (it is possible to pass
- multiple points for a single mask), and the last dimension is the x (vertical) and y (horizontal)
- coordinates of the point. If a different number of points is passed either for each image, or for each
- mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
- computation of the embedding will be skipped for these points using the labels.
- input_labels (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points)`):
- Input labels for the points; these are used by the prompt encoder to encode the prompt. According to the
- official implementation, there are 3 types of labels:
- - `1`: the point is a point that contains the object of interest
- - `0`: the point is a point that does not contain the object of interest
- - `-1`: the point corresponds to the background
- We added the label:
- - `-10`: the point is a padding point, thus it should be ignored by the prompt encoder
- The padding labels are added automatically by the processor.
- input_boxes (`tf.Tensor` of shape `(batch_size, num_boxes, 4)`):
- Input boxes for the points; these are used by the prompt encoder to encode the prompt and generally yield
- much better generated masks. The boxes can be obtained by passing a list of lists of lists to the processor,
- which will generate a `tf` tensor, with each dimension corresponding respectively to the image batch size,
- the number of boxes per image and the coordinates of the top left and bottom right point of the box, in the
- order (`x1`, `y1`, `x2`, `y2`):
- - `x1`: the x coordinate of the top left point of the input box
- - `y1`: the y coordinate of the top left point of the input box
- - `x2`: the x coordinate of the bottom right point of the input box
- - `y2`: the y coordinate of the bottom right point of the input box
- input_masks (`tf.Tensor` of shape `(batch_size, image_size, image_size)`):
- The SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
- generate a corresponding embedding that will later be fed to the mask decoder. These masks need to be
- fed manually by the user, and they must be of shape (`batch_size`, `image_size`, `image_size`).
- image_embeddings (`tf.Tensor` of shape `(batch_size, output_channels, window_size, window_size)`):
- Image embeddings; these are used by the mask decoder to generate masks and IoU scores. For more
- memory-efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings`
- method, and then feed them to the `call` method instead of feeding the `pixel_values`.
- multimask_output (`bool`, *optional*):
- In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
- bounding box if relevant). However, it is possible to output just a single mask, the one that corresponds to
- the "best" mask, by specifying `multimask_output=False`.
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
- tensors for more detail.
- output_hidden_states (`bool`, *optional*):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
- more detail.
- return_dict (`bool`, *optional*):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
- """
- @add_start_docstrings(
- "Segment Anything Model (SAM) for generating segmentation masks, given an input image and ",
- " optional 2D location and bounding boxes.",
- SAM_START_DOCSTRING,
- )
- class TFSamModel(TFSamPreTrainedModel):
- _keys_to_ignore_on_load_missing = [r"prompt_encoder.shared_embedding.positional_embedding"]
- def __init__(self, config, **kwargs):
- super().__init__(config, **kwargs)
- self.shared_image_embedding = TFSamPositionalEmbedding(config.vision_config, name="shared_image_embedding")
- self.vision_encoder = TFSamVisionEncoder(config.vision_config, name="vision_encoder")
- self.prompt_encoder = TFSamPromptEncoder(
- config.prompt_encoder_config, self.shared_image_embedding, name="prompt_encoder"
- )
- self.mask_decoder = TFSamMaskDecoder(config.mask_decoder_config, name="mask_decoder")
- self.config = config
- def get_input_embeddings(self):
- return self.vision_encoder.get_input_embeddings()
- def get_image_wide_positional_embeddings(self):
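- # Build per-location positional encodings for the image embedding grid: a grid of (x, y) coordinates
- # normalized to (0, 1) is passed through the shared positional embedding layer.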
- size = self.config.prompt_encoder_config.image_embedding_size
- grid = tf.ones((size, size))
- y_embed = tf.math.cumsum(grid, axis=0) - 0.5
- x_embed = tf.math.cumsum(grid, axis=1) - 0.5
- y_embed = y_embed / size
- x_embed = x_embed / size
- positional_embedding = self.shared_image_embedding(tf.stack([x_embed, y_embed], axis=-1))
- return tf.expand_dims(tf.transpose(positional_embedding, perm=[2, 0, 1]), axis=0) # 1 x channel x height x width
- def get_image_embeddings(
- self,
- pixel_values,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ):
- r"""
- Returns the image embeddings by passing the pixel values through the vision encoder.
- Args:
- pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
- Input pixel values
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers.
- output_hidden_states (`bool`, *optional*):
- Whether or not to return the hidden states of all layers.
- return_dict (`bool`, *optional*):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
- """
- vision_output = self.vision_encoder(
- pixel_values,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
- image_embeddings = vision_output[0]
- return image_embeddings
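- # Usage sketch (an assumption, not part of the original file): the embeddings can be computed once and reused
- # for several prompts by passing `image_embeddings` to `call` instead of `pixel_values`, e.g.
- #   image_embeddings = model.get_image_embeddings(inputs["pixel_values"])
- #   outputs = model(input_points=inputs["input_points"], image_embeddings=image_embeddings)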
- def get_prompt_embeddings(
- self,
- input_points: tf.Tensor | None = None,
- input_labels: tf.Tensor | None = None,
- input_boxes: tf.Tensor | None = None,
- input_masks: tf.Tensor | None = None,
- ):
- r"""
- Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.
- Args:
- input_points (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
- Optional input points for the prompt encoder. The padding of the points is automatically done by the
- processor. `point_batch_size` refers to the number of masks that we want the model to predict per
- point. The model will output `point_batch_size` times 3 masks in total.
- input_labels (`tf.Tensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
- Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
- processor, or can be fed by the user.
- input_boxes (`tf.Tensor` of shape `(batch_size, num_boxes_per_image, 4)`):
- Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
- processor. Users can also pass the input boxes manually.
- input_masks (`tf.Tensor` of shape `(batch_size, image_size, image_size)`):
- Optional input masks for the prompt encoder.
- """
- prompt_output = self.prompt_encoder(
- input_points=input_points,
- input_labels=input_labels,
- input_boxes=input_boxes,
- input_masks=input_masks,
- )
- return prompt_output
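- # Usage sketch (an assumption): prompt embeddings can likewise be computed separately, e.g.
- #   sparse_embeddings, dense_embeddings = model.get_prompt_embeddings(input_points=inputs["input_points"])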
- @unpack_inputs
- @add_start_docstrings_to_model_forward(SAM_INPUTS_DOCSTRING)
- def call(
- self,
- pixel_values: TFModelInputType | None = None,
- input_points: tf.Tensor | None = None,
- input_labels: tf.Tensor | None = None,
- input_boxes: tf.Tensor | None = None,
- input_masks: tf.Tensor | None = None,
- image_embeddings: tf.Tensor | None = None,
- multimask_output: bool = True,
- output_attentions: bool | None = None,
- output_hidden_states: bool | None = None,
- return_dict: bool | None = None,
- training: bool = False,
- **kwargs,
- ) -> TFSamImageSegmentationOutput | Tuple[tf.Tensor]:
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- if pixel_values is None and image_embeddings is None:
- raise ValueError("Either pixel_values or image_embeddings must be provided.")
- if pixel_values is not None and image_embeddings is not None:
- raise ValueError("Only one of pixel_values and image_embeddings can be provided.")
- if input_points is not None and len(input_points.shape) != 4:
- raise ValueError(
- "The input_points must be a 4D tensor of shape (`batch_size`, `point_batch_size`, `nb_points_per_image`, `2`),"
- " got {}.".format(input_points.shape)
- )
- if input_boxes is not None and len(input_boxes.shape) != 3:
- raise ValueError(
- "The input_boxes must be a 3D tensor of shape (`batch_size`, `nb_boxes`, `4`),"
- " got {}.".format(input_boxes.shape)
- )
- if input_points is not None and input_boxes is not None:
- point_batch_size = shape_list(input_points)[1]
- box_batch_size = shape_list(input_boxes)[1]
- if point_batch_size != box_batch_size:
- raise ValueError(
- "You should provide as many bounding boxes per image as point batches (one box per mask to predict)."
- " Got {} point batches and {} boxes per image.".format(point_batch_size, box_batch_size)
- )
- if pixel_values is not None:
- # Ensures that later checks pass even with an all-None shape from the serving signature
- pixel_values = tf.ensure_shape(
- pixel_values,
- [
- None,
- self.config.vision_config.num_channels,
- self.config.vision_config.image_size,
- self.config.vision_config.image_size,
- ],
- )
- image_positional_embeddings = self.get_image_wide_positional_embeddings()
- # repeat with batch size
- batch_size = shape_list(pixel_values)[0] if pixel_values is not None else shape_list(image_embeddings)[0]
- image_positional_embeddings = tf.repeat(image_positional_embeddings, batch_size, axis=0)
- vision_attentions = None
- vision_hidden_states = None
- if pixel_values is not None:
- vision_outputs = self.vision_encoder(
- pixel_values,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=True,
- training=training,
- )
- image_embeddings = vision_outputs["last_hidden_state"]
- if output_hidden_states:
- vision_hidden_states = vision_outputs["hidden_states"]
- if output_attentions:
- vision_attentions = vision_outputs["attentions"]
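- # When points are provided without labels, treat every point as a positive (foreground) prompt (label 1).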
- if input_points is not None and input_labels is None:
- input_labels = tf.ones_like(input_points[:, :, :, 0], dtype=tf.int32)
- if input_points is not None and image_embeddings.shape[0] != input_points.shape[0]:
- raise ValueError(
- "The batch size of the image embeddings and the input points must be the same. "
- "Got {} and {} respectively. If you want to pass multiple points for the same image, "
- "make sure that you passed input_points of shape "
- "(batch_size, point_batch_size, num_points_per_image, 2) and input_labels of shape "
- "(batch_size, point_batch_size, num_points_per_image).".format(
- image_embeddings.shape[0], input_points.shape[0]
- )
- )
- sparse_embeddings, dense_embeddings = self.prompt_encoder(
- batch_size=shape_list(image_embeddings)[0],
- input_points=input_points,
- input_labels=input_labels,
- input_boxes=input_boxes,
- input_masks=input_masks,
- )
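- # The mask decoder fuses the image embeddings with the sparse (points/boxes) and dense (mask) prompt
- # embeddings to predict low-resolution masks and their IoU scores.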
- low_res_masks, iou_predictions, mask_decoder_attentions = self.mask_decoder(
- image_embeddings=image_embeddings,
- image_positional_embeddings=image_positional_embeddings,
- sparse_prompt_embeddings=sparse_embeddings,
- dense_prompt_embeddings=dense_embeddings,
- multimask_output=multimask_output,
- output_attentions=output_attentions,
- )
- if not return_dict:
- output = (iou_predictions, low_res_masks)
- if output_hidden_states:
- output = output + (vision_hidden_states,)
- if output_attentions:
- output = output + (vision_attentions, mask_decoder_attentions)
- return output
- return TFSamImageSegmentationOutput(
- iou_scores=iou_predictions,
- pred_masks=low_res_masks,
- vision_hidden_states=vision_hidden_states,
- vision_attentions=vision_attentions,
- mask_decoder_attentions=mask_decoder_attentions,
- )
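- # End-to-end usage sketch (the checkpoint name, prompt values and `raw_image` are assumptions, not taken
- # from this file):
- #   from transformers import SamProcessor, TFSamModel
- #   processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
- #   model = TFSamModel.from_pretrained("facebook/sam-vit-base")
- #   inputs = processor(raw_image, input_points=[[[450, 600]]], return_tensors="tf")
- #   outputs = model(**inputs)
- #   masks = processor.image_processor.post_process_masks(
- #       outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"], return_tensors="tf"
- #   )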
- def serving_output(self, output: TFSamImageSegmentationOutput) -> TFSamImageSegmentationOutput:
- hs = tf.convert_to_tensor(output.vision_hidden_states) if self.config.output_hidden_states else None
- attns = tf.convert_to_tensor(output.vision_attentions) if self.config.output_attentions else None
- return TFSamImageSegmentationOutput(
- iou_scores=output.iou_scores,
- pred_masks=output.pred_masks,
- vision_hidden_states=hs if self.config.output_hidden_states else None,
- vision_attentions=attns if self.config.output_attentions else None,
- mask_decoder_attentions=output.mask_decoder_attentions if self.config.output_attentions else None,
- )
- def build(self, input_shape=None):
- if self.built:
- return
- self.built = True
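- # Build each submodule under its own name scope so that the weight names line up with pretrained checkpoints.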
- if getattr(self, "shared_image_embedding", None) is not None:
- with tf.name_scope(self.shared_image_embedding.name):
- self.shared_image_embedding.build(None)
- if getattr(self, "vision_encoder", None) is not None:
- with tf.name_scope(self.vision_encoder.name):
- self.vision_encoder.build(None)
- if getattr(self, "prompt_encoder", None) is not None:
- with tf.name_scope(self.prompt_encoder.name):
- self.prompt_encoder.build(None)
- if getattr(self, "mask_decoder", None) is not None:
- with tf.name_scope(self.mask_decoder.name):
- self.mask_decoder.build(None)