# This code was adapted from https://github.com/lucidrains/flamingo-pytorch licensed under the MIT License.
#
# MIT License
#
# Copyright (c) 2020 The Google AI Language Team Authors, The HuggingFace Inc. team and github/lonePatient
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
  24. """
  25. Generic interface to various configurations of the Perceiver Resampler, that simply takes in a series of (potentially
  26. time-indexed) contextual embeddings, and "resamples" (compresses) them down to a pre-specified number of latents! Note
  27. that the Perceiver in general resamples based solely off the *long-range* context; there's a nice opportunity here to
  28. prime the Perceiver Resampler with say a single layer's worth of language embeddings (the target domain), and use that
  29. to softly "retrieve & compress" what we need --> this would be a novel contribution we should explore.
  30. References:
  31. - DeepMind's Flamingo: https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model
  32. - Code borrowed w/ love from: https://github.com/lucidrains/flamingo-pytorch
  33. """

import tensorflow as tf

from ...modeling_tf_utils import shape_list
from .configuration_idefics import IdeficsConfig


class TFIdeficsPerceiverResampler(tf.keras.layers.Layer):
    def __init__(
        self, config: IdeficsConfig, embed_dim: int, depth: int, n_heads: int, head_dim: int, n_latents: int, **kwargs
    ) -> None:
        """
        Instantiates a Perceiver Resampler that operates over a sequence of embeddings (say from a ResNet or ViT or
        MAE) of a given dimension, performs `depth` blocks of cross-attention with a fixed set of `n_latents` latent
        embeddings, then returns a Tensor of shape `[bsz, n_latents, embed_dim]`.

        Args:
            config (`IdeficsConfig`): config object
            embed_dim (`int`):
                Dimensionality of the embeddings being fed to the Perceiver Resampler (and of the latent embeddings
                it *returns*); could be e.g. the ViT embed dim, the ResNet pool dim, and so on.
            depth (`int`): Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (< 3).
            n_heads (`int`): Number of heads in each Transformer block (for multi-headed self-attention).
            head_dim (`int`): Dimensionality of each head projection in the Transformer block.
            n_latents (`int`):
                Number of latent embeddings to resample ("compress") the input sequence to (usually < 128).
        """
        super().__init__(**kwargs)
        self.embed_dim, self.n_heads, self.head_dim, self.n_latents = embed_dim, n_heads, head_dim, n_latents
        self.qk_layer_norms = config.perceiver_config.qk_layer_norms_perceiver
        # Standard 4x Transformer FFN expansion, sized off the vision embed dim when one is configured
        self.intermediate_dim = (
            self.embed_dim * 4
            if not hasattr(config.vision_config, "embed_dim")
            else config.vision_config.embed_dim * 4
        )

        # Create Transformer Blocks
        self.blocks = []
        for i in range(depth):
            self.blocks.append(
                [
                    TFIdeficsPerceiverAttention(
                        self.embed_dim, self.n_heads, self.head_dim, self.qk_layer_norms, name=f"blocks.{i}.0"
                    ),
                    TFIdeficsMLP(self.intermediate_dim, config, name=f"blocks.{i}.1"),
                ]
            )
        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layer_norm")

    def build(self, input_shape):
        # Create Latents for Perceiver
        self.latents = self.add_weight(
            shape=(self.n_latents, self.embed_dim), initializer="random_normal", trainable=True, name="latents"
        )
        super().build(input_shape)

    def call(self, context: tf.Tensor) -> tf.Tensor:
        """Resample arbitrary length context & *compress* down to self.n_latents latent embeddings"""
        # Einops equivalent: repeat(self.latents, "seq embed -> bsz seq embed", bsz=batch_size)
        latents = tf.expand_dims(self.latents, axis=0)
        latents = tf.tile(latents, [tf.shape(context)[0], 1, 1])

        # Feed through Perceiver Attention blocks...
        for attn, ff in self.blocks:
            latents = attn(context, latents) + latents
            latents = ff(latents) + latents

        return self.layer_norm(latents)
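

# A minimal usage sketch (all sizes below are hypothetical; in practice this layer is
# built inside the TF IDEFICS model from values in `IdeficsConfig`, not standalone):
#
#     resampler = TFIdeficsPerceiverResampler(
#         config, embed_dim=1024, depth=3, n_heads=16, head_dim=96, n_latents=64
#     )
#     image_features = tf.random.normal((2, 257, 1024))  # [bsz, seq, embed_dim], e.g. ViT patch embeddings
#     latents = resampler(image_features)                # [2, 64, 1024] -- seq compressed to 64 latents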


class TFIdeficsPerceiverAttention(tf.keras.layers.Layer):
    def __init__(self, embed_dim: int, n_heads: int, head_dim: int, qk_layer_norms: bool, **kwargs) -> None:
        """Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`"""
        super().__init__(**kwargs)
        self.embed_dim, self.n_heads, self.head_dim = embed_dim, n_heads, head_dim
        self.qk_layer_norms = qk_layer_norms
        # Normalization & Scaling
        self.context_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="context_layer_norm")
        self.latents_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="latents_layer_norm")
        if self.qk_layer_norms:
            self.q_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="q_layer_norm")
            self.k_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="k_layer_norm")
        self.qk_scale = self.head_dim**-0.5

        # Q, K, V Projections (no bias -- detail from the Perceiver/Flamingo papers)
        self.q_proj = tf.keras.layers.Dense(self.n_heads * self.head_dim, use_bias=False, name="q_proj")
        self.k_proj = tf.keras.layers.Dense(self.n_heads * self.head_dim, use_bias=False, name="k_proj")
        self.v_proj = tf.keras.layers.Dense(self.n_heads * self.head_dim, use_bias=False, name="v_proj")
        self.output_proj = tf.keras.layers.Dense(embed_dim, use_bias=False, name="output_proj")

    def call(self, context: tf.Tensor, latents: tf.Tensor) -> tf.Tensor:
        """
        Runs Perceiver cross-attention, with the latents appended to the context along the `seq` dimension to form
        the keys/values.

        Args:
            context (`tf.Tensor`):
                Tensor of shape `[bsz, seq, embed_dim]` representing long-form context to resample.
            latents (`tf.Tensor`):
                Tensor of shape `[bsz, n_latents, embed_dim]` representing fixed length latents to compress to.

        Returns:
            `tf.Tensor`: Tensor of shape `[bsz, n_latents, embed_dim]` representing attention over latents w/ cross
            from context.
        """
        context = self.context_layer_norm(context)
        latents = self.latents_layer_norm(latents)
        batch_size, seq_length, embed_dim = shape_list(context)

        # Query, Key, Value projections --> note that in Flamingo, latents are *concatenated* with context prior to attn!
        # This gives queries w/ `seq = n_latents`, and keys/values w/ `seq = seq_length + n_latents`.
        q = self.q_proj(latents)
        k = self.k_proj(tf.concat([context, latents], axis=-2))
        v = self.v_proj(tf.concat([context, latents], axis=-2))

        # Split heads: [bsz, seq, n_heads * head_dim] --> [bsz, n_heads, seq, head_dim]
        # (-1 keeps the reshape valid when the sequence length is only known at graph run time)
        q, k, v = [
            tf.transpose(tf.reshape(x, (batch_size, -1, self.n_heads, self.head_dim)), perm=[0, 2, 1, 3])
            for x in (q, k, v)
        ]
        if self.qk_layer_norms:
            q = self.q_layer_norm(q)
            k = self.k_layer_norm(k)

        # Multiheaded attention w/ stable softmax (subtract the per-row max before the softmax call)
        # =>> per head, `attn` has shape [n_latents, seq_length + n_latents]
        scores = tf.einsum("... i d, ... j d -> ... i j", q * self.qk_scale, k)
        stabilized_scores = scores - tf.reduce_max(scores, axis=-1, keepdims=True)
        attn = tf.nn.softmax(stabilized_scores, axis=-1)

        # Attend & project back to output...
        resampled = tf.einsum("... i j, ... j d -> ... i d", attn, v)
        # Merge heads back: [bsz, n_heads, n_latents, head_dim] --> [bsz, n_latents, n_heads * head_dim]
        return self.output_proj(
            tf.reshape(tf.transpose(resampled, perm=[0, 2, 1, 3]), (batch_size, -1, self.n_heads * self.head_dim))
        )
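

# Shape walkthrough for one attention call, using hypothetical sizes (bsz=2, seq=257,
# n_latents=64, n_heads=16, head_dim=96) -- these numbers are illustrative, not from the config:
#
#     q:      [2, 16,  64, 96]   (queries come only from the 64 latents)
#     k, v:   [2, 16, 321, 96]   (keys/values span context + latents: 257 + 64 = 321)
#     attn:   [2, 16,  64, 321]  (each latent attends over the full concatenated sequence)
#     output: [2, 64, 1536] --> output_proj --> [2, 64, embed_dim]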


class TFIdeficsMLP(tf.keras.layers.Layer):
    def __init__(self, intermediate_size, config: IdeficsConfig, **kwargs):
        """Simple MLP block with intermediate_size and embedding size"""
        super().__init__(**kwargs)
        self.embed_dim = config.vision_config.embed_dim
        self.ln = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="ln")
        self.fc = tf.keras.layers.Dense(intermediate_size, use_bias=False, name="fc")
        self.act = tf.keras.layers.ReLU(name="act")
        self.c_proj = tf.keras.layers.Dense(self.embed_dim, use_bias=False, name="c_proj")

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        hidden_states = self.ln(hidden_states)
        hidden_states = self.fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        return hidden_states
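

# The MLP above is a pre-LayerNorm feed-forward block without biases, i.e. (in pseudo-math):
#
#     out = c_proj(relu(fc(ln(x))))   # [bsz, seq, embed_dim] -> [bsz, seq, intermediate_size] -> [bsz, seq, embed_dim]
#
# Note there is no internal residual connection: TFIdeficsPerceiverResampler.call adds
# `ff(latents) + latents` itself, keeping the residual wiring at the block level.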