modeling_tf_vit.py
# coding=utf-8
# Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF 2.0 ViT model."""

from __future__ import annotations

import collections.abc
import math
from typing import Optional, Tuple, Union

import numpy as np
import tensorflow as tf

from ...activations_tf import get_tf_activation
from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling, TFSequenceClassifierOutput
from ...modeling_tf_utils import (
    TFModelInputType,
    TFPreTrainedModel,
    TFSequenceClassificationLoss,
    get_initializer,
    keras,
    keras_serializable,
    unpack_inputs,
)
from ...tf_utils import shape_list, stable_softmax
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_vit import ViTConfig


logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "ViTConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "google/vit-base-patch16-224-in21k"
_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "google/vit-base-patch16-224"
_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"


class TFViTEmbeddings(keras.layers.Layer):
    """
    Construct the CLS token, position and patch embeddings.
    """

    def __init__(self, config: ViTConfig, **kwargs):
        super().__init__(**kwargs)

        self.patch_embeddings = TFViTPatchEmbeddings(config, name="patch_embeddings")
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.config = config

    def build(self, input_shape=None):
        num_patches = self.patch_embeddings.num_patches
        self.cls_token = self.add_weight(
            shape=(1, 1, self.config.hidden_size),
            initializer=get_initializer(self.config.initializer_range),
            trainable=True,
            name="cls_token",
        )
        self.position_embeddings = self.add_weight(
            shape=(1, num_patches + 1, self.config.hidden_size),
            initializer=get_initializer(self.config.initializer_range),
            trainable=True,
            name="position_embeddings",
        )

        if self.built:
            return
        self.built = True
        if getattr(self, "patch_embeddings", None) is not None:
            with tf.name_scope(self.patch_embeddings.name):
                self.patch_embeddings.build(None)

    def interpolate_pos_encoding(self, embeddings, height, width) -> tf.Tensor:
  75. """
  76. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
  77. resolution images.
  78. Source:
  79. https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
  80. """
        batch_size, seq_len, dim = shape_list(embeddings)
        num_patches = seq_len - 1

        _, num_positions, _ = shape_list(self.position_embeddings)
        num_positions -= 1

        if num_patches == num_positions and height == width:
            return self.position_embeddings
        class_pos_embed = self.position_embeddings[:, :1]
        patch_pos_embed = self.position_embeddings[:, 1:]
        h0 = height // self.config.patch_size
        w0 = width // self.config.patch_size
        patch_pos_embed = tf.image.resize(
            images=tf.reshape(
                patch_pos_embed, shape=(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
            ),
            size=(h0, w0),
            method="bicubic",
        )

        shape = shape_list(patch_pos_embed)
        assert h0 == shape[-3] and w0 == shape[-2]
        patch_pos_embed = tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim))
        return tf.concat(values=(class_pos_embed, patch_pos_embed), axis=1)
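
    # Worked example (an illustrative editor's note, not part of the original file): with
    # the default ViT-Base geometry (patch_size=16, hidden_size=768) pre-trained at
    # 224x224, num_positions = 196, so the stored grid is sqrt(196) = 14 patches per side.
    # Running at 384x384 gives h0 = w0 = 384 // 16 = 24: the (1, 196, 768) patch table is
    # reshaped to (1, 14, 14, 768), bicubically resized to (1, 24, 24, 768), flattened
    # back to (1, 576, 768), and the [CLS] position embedding is re-prepended to yield a
    # (1, 577, 768) position table matching the longer token sequence.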

    def call(
        self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False, training: bool = False
    ) -> tf.Tensor:
        batch_size, num_channels, height, width = shape_list(pixel_values)
        embeddings = self.patch_embeddings(
            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, training=training
        )

        # add the [CLS] token to the embedded patch tokens
        cls_tokens = tf.repeat(self.cls_token, repeats=batch_size, axis=0)
        embeddings = tf.concat((cls_tokens, embeddings), axis=1)

        # add positional encoding to each token
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embeddings

        embeddings = self.dropout(embeddings, training=training)

        return embeddings


# Based on timm implementation, which can be found here:
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
class TFViTPatchEmbeddings(keras.layers.Layer):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config: ViTConfig, **kwargs):
        super().__init__(**kwargs)
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size

        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.num_channels = num_channels
        self.config = config

        self.projection = keras.layers.Conv2D(
            filters=hidden_size,
            kernel_size=patch_size,
            strides=patch_size,
            padding="valid",
            data_format="channels_last",
            use_bias=True,
            kernel_initializer=get_initializer(self.config.initializer_range),
            bias_initializer="zeros",
            name="projection",
        )

    def call(
        self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False, training: bool = False
    ) -> tf.Tensor:
        batch_size, num_channels, height, width = shape_list(pixel_values)
        if tf.executing_eagerly() and num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values matches the one set in the configuration."
            )
        if not interpolate_pos_encoding:
            if tf.executing_eagerly():
                if height != self.image_size[0] or width != self.image_size[1]:
                    raise ValueError(
                        f"Input image size ({height}*{width}) doesn't match model"
                        f" ({self.image_size[0]}*{self.image_size[1]})."
                    )

        # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format.
        # So change the input format from `NCHW` to `NHWC`.
        # shape = (batch_size, in_height, in_width, in_channels=num_channels)
        pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))

        projection = self.projection(pixel_values)

        # Change the 2D spatial dimensions to a single temporal dimension.
        # shape = (batch_size, num_patches, out_channels=embed_dim)
        num_patches = (width // self.patch_size[1]) * (height // self.patch_size[0])
        embeddings = tf.reshape(tensor=projection, shape=(batch_size, num_patches, -1))

        return embeddings

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "projection", None) is not None:
            with tf.name_scope(self.projection.name):
                self.projection.build([None, None, None, self.num_channels])
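

# Shape sketch (an illustrative editor's note, assuming the default
# google/vit-base-patch16-224 geometry): a (1, 3, 224, 224) NCHW input is transposed to
# (1, 224, 224, 3), the strided Conv2D emits (1, 14, 14, 768) since 224 // 16 = 14, and
# the reshape flattens that to (1, 196, 768), i.e.
# num_patches = (224 // 16) * (224 // 16) = 196 tokens before TFViTEmbeddings prepends
# the [CLS] token. A minimal standalone check of that arithmetic:
#
#     import tensorflow as tf
#
#     dummy = tf.zeros((1, 3, 224, 224))  # NCHW, like `pixel_values`
#     conv = tf.keras.layers.Conv2D(filters=768, kernel_size=16, strides=16)
#     out = conv(tf.transpose(dummy, perm=(0, 2, 3, 1)))  # NCHW -> NHWC
#     assert out.shape == (1, 14, 14, 768)
#     assert tf.reshape(out, (1, -1, 768)).shape == (1, 196, 768)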


class TFViTSelfAttention(keras.layers.Layer):
    def __init__(self, config: ViTConfig, **kwargs):
        super().__init__(**kwargs)

        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number "
                f"of attention heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.sqrt_att_head_size = math.sqrt(self.attention_head_size)

        self.query = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
        )
        self.key = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
        )
        self.value = keras.layers.Dense(
            units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
        )
        self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
        self.config = config

    def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
        # Reshape from [batch_size, seq_length, all_head_size] to
        # [batch_size, seq_length, num_attention_heads, attention_head_size]
        tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))

        # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to
        # [batch_size, num_attention_heads, seq_length, attention_head_size]
        return tf.transpose(tensor, perm=[0, 2, 1, 3])

    def call(
        self,
        hidden_states: tf.Tensor,
        head_mask: tf.Tensor,
        output_attentions: bool,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        batch_size = shape_list(hidden_states)[0]
        mixed_query_layer = self.query(inputs=hidden_states)
        mixed_key_layer = self.key(inputs=hidden_states)
        mixed_value_layer = self.value(inputs=hidden_states)
        query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
        key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
        value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        # (batch size, num_heads, seq_len_q, seq_len_k)
        attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
        dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
        attention_scores = tf.divide(attention_scores, dk)

        # Normalize the attention scores to probabilities.
        attention_probs = stable_softmax(logits=attention_scores, axis=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(inputs=attention_probs, training=training)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = tf.multiply(attention_probs, head_mask)

        attention_output = tf.matmul(attention_probs, value_layer)
        attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])

        # (batch_size, seq_len_q, all_head_size)
        attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
        outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)

        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "query", None) is not None:
            with tf.name_scope(self.query.name):
                self.query.build([None, None, self.config.hidden_size])
        if getattr(self, "key", None) is not None:
            with tf.name_scope(self.key.name):
                self.key.build([None, None, self.config.hidden_size])
        if getattr(self, "value", None) is not None:
            with tf.name_scope(self.value.name):
                self.value.build([None, None, self.config.hidden_size])
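

# Attention shape walk (an illustrative editor's note): for ViT-Base (hidden_size=768,
# num_attention_heads=12, attention_head_size=64) and a 197-token sequence, query/key/value
# are each (batch, 197, 768) and transpose_for_scores turns them into (batch, 12, 197, 64).
# The QK^T matmul then yields (batch, 12, 197, 197) scores, scaled by
# sqrt_att_head_size = sqrt(64) = 8 before the softmax; multiplying the probabilities by
# the values and merging the heads restores (batch, 197, 768).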


class TFViTSelfOutput(keras.layers.Layer):
    """
    The residual connection is defined in TFViTLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: ViTConfig, **kwargs):
        super().__init__(**kwargs)

        self.dense = keras.layers.Dense(
            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.config = config

    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
        hidden_states = self.dense(inputs=hidden_states)
        hidden_states = self.dropout(inputs=hidden_states, training=training)

        return hidden_states

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])


class TFViTAttention(keras.layers.Layer):
    def __init__(self, config: ViTConfig, **kwargs):
        super().__init__(**kwargs)

        self.self_attention = TFViTSelfAttention(config, name="attention")
        self.dense_output = TFViTSelfOutput(config, name="output")

    def prune_heads(self, heads):
        raise NotImplementedError

    def call(
        self,
        input_tensor: tf.Tensor,
        head_mask: tf.Tensor,
        output_attentions: bool,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        self_outputs = self.self_attention(
            hidden_states=input_tensor, head_mask=head_mask, output_attentions=output_attentions, training=training
        )
        attention_output = self.dense_output(
            hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
        )
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them

        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "self_attention", None) is not None:
            with tf.name_scope(self.self_attention.name):
                self.self_attention.build(None)
        if getattr(self, "dense_output", None) is not None:
            with tf.name_scope(self.dense_output.name):
                self.dense_output.build(None)


class TFViTIntermediate(keras.layers.Layer):
    def __init__(self, config: ViTConfig, **kwargs):
        super().__init__(**kwargs)

        self.dense = keras.layers.Dense(
            units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )

        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = get_tf_activation(config.hidden_act)
        else:
            self.intermediate_act_fn = config.hidden_act
        self.config = config

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        hidden_states = self.dense(inputs=hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)

        return hidden_states

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])


class TFViTOutput(keras.layers.Layer):
    def __init__(self, config: ViTConfig, **kwargs):
        super().__init__(**kwargs)

        self.dense = keras.layers.Dense(
            units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
        )
        self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob)
        self.config = config

    def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
        hidden_states = self.dense(inputs=hidden_states)
        hidden_states = self.dropout(inputs=hidden_states, training=training)
        hidden_states = hidden_states + input_tensor

        return hidden_states

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.intermediate_size])


class TFViTLayer(keras.layers.Layer):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config: ViTConfig, **kwargs):
        super().__init__(**kwargs)

        self.attention = TFViTAttention(config, name="attention")
        self.intermediate = TFViTIntermediate(config, name="intermediate")
        self.vit_output = TFViTOutput(config, name="output")

        self.layernorm_before = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_before")
        self.layernorm_after = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_after")
        self.config = config

    def call(
        self,
        hidden_states: tf.Tensor,
        head_mask: tf.Tensor,
        output_attentions: bool,
        training: bool = False,
    ) -> Tuple[tf.Tensor]:
        attention_outputs = self.attention(
            # in ViT, layernorm is applied before self-attention
            input_tensor=self.layernorm_before(inputs=hidden_states),
            head_mask=head_mask,
            output_attentions=output_attentions,
            training=training,
        )
        attention_output = attention_outputs[0]

        # first residual connection
        hidden_states = attention_output + hidden_states

        # in ViT, layernorm is also applied after self-attention
        layer_output = self.layernorm_after(inputs=hidden_states)

        intermediate_output = self.intermediate(hidden_states=layer_output)

        # second residual connection is done here
        layer_output = self.vit_output(
            hidden_states=intermediate_output, input_tensor=hidden_states, training=training
        )
        outputs = (layer_output,) + attention_outputs[1:]  # add attentions if we output them

        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "attention", None) is not None:
            with tf.name_scope(self.attention.name):
                self.attention.build(None)
        if getattr(self, "intermediate", None) is not None:
            with tf.name_scope(self.intermediate.name):
                self.intermediate.build(None)
        if getattr(self, "vit_output", None) is not None:
            with tf.name_scope(self.vit_output.name):
                self.vit_output.build(None)
        if getattr(self, "layernorm_before", None) is not None:
            with tf.name_scope(self.layernorm_before.name):
                self.layernorm_before.build([None, None, self.config.hidden_size])
        if getattr(self, "layernorm_after", None) is not None:
            with tf.name_scope(self.layernorm_after.name):
                self.layernorm_after.build([None, None, self.config.hidden_size])
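

# The layer above is a standard pre-norm transformer block. Schematically (dropout
# omitted), for an input x it computes:
#
#     y = x + Attention(layernorm_before(x))
#     out = y + Output(Intermediate(layernorm_after(y)))   # the GELU MLP branch
#
# Unlike post-norm BERT-style blocks, normalization is applied to the inputs of each
# residual branch, which is why TFViTMainLayer still applies one final layer norm after
# the last encoder block.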


class TFViTEncoder(keras.layers.Layer):
    def __init__(self, config: ViTConfig, **kwargs):
        super().__init__(**kwargs)

        self.layer = [TFViTLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]

    def call(
        self,
        hidden_states: tf.Tensor,
        head_mask: tf.Tensor,
        output_attentions: bool,
        output_hidden_states: bool,
        return_dict: bool,
        training: bool = False,
    ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(
                hidden_states=hidden_states,
                head_mask=head_mask[i],
                output_attentions=output_attentions,
                training=training,
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)

        return TFBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "layer", None) is not None:
            for layer in self.layer:
                with tf.name_scope(layer.name):
                    layer.build(None)


@keras_serializable
class TFViTMainLayer(keras.layers.Layer):
    config_class = ViTConfig

    def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, **kwargs):
        super().__init__(**kwargs)

        self.config = config

        self.embeddings = TFViTEmbeddings(config, name="embeddings")
        self.encoder = TFViTEncoder(config, name="encoder")
        self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
        self.pooler = TFViTPooler(config, name="pooler") if add_pooling_layer else None

    def get_input_embeddings(self) -> keras.layers.Layer:
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See
        base class `PreTrainedModel`.
        """
        raise NotImplementedError

    @unpack_inputs
    def call(
        self,
        pixel_values: TFModelInputType | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedding_output = self.embeddings(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
            training=training,
        )

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.config.num_hidden_layers

        encoder_outputs = self.encoder(
            hidden_states=embedding_output,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(inputs=sequence_output)
        pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return TFBaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "embeddings", None) is not None:
            with tf.name_scope(self.embeddings.name):
                self.embeddings.build(None)
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)
        if getattr(self, "layernorm", None) is not None:
            with tf.name_scope(self.layernorm.name):
                self.layernorm.build([None, None, self.config.hidden_size])
        if getattr(self, "pooler", None) is not None:
            with tf.name_scope(self.pooler.name):
                self.pooler.build(None)


class TFViTPreTrainedModel(TFPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = ViTConfig
    base_model_prefix = "vit"
    main_input_name = "pixel_values"


VIT_START_DOCSTRING = r"""

    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
    behavior.

    <Tip>

    TensorFlow models and layers in `transformers` accept two formats as input:

    - having all inputs as keyword arguments (like PyTorch models), or
    - having all inputs as a list, tuple or dict in the first positional argument.

    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
    positional argument:

    - a single Tensor with `pixel_values` only and nothing else: `model(pixel_values)`
    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
    `model([pixel_values, head_mask])`
    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
    `model({"pixel_values": pixel_values, "head_mask": head_mask})`

    Note that when creating models and layers with
    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
    about any of this, as you can just pass inputs like you would to any other Python function!

    </Tip>

    Args:
        config ([`ViTConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""

VIT_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]`, and each example must have the shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
            for details.
        head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail. This argument can be used only in eager mode; in graph mode the value in the
            config will be used instead.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail. This argument can be used only in eager mode; in graph mode the value in the config will be
            used instead.
        interpolate_pos_encoding (`bool`, *optional*):
            Whether to interpolate the pre-trained position encodings.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
            eager mode; in graph mode the value will always be set to `True`.
        training (`bool`, *optional*, defaults to `False`):
            Whether or not to use the model in training mode (some modules like dropout modules have different
            behaviors between training and evaluation).
"""


@add_start_docstrings(
    "The bare ViT Model transformer outputting raw hidden-states without any specific head on top.",
    VIT_START_DOCSTRING,
)
class TFViTModel(TFViTPreTrainedModel):
    def __init__(self, config: ViTConfig, *inputs, add_pooling_layer=True, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.vit = TFViTMainLayer(config, add_pooling_layer=add_pooling_layer, name="vit")

    @unpack_inputs
    @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TFBaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def call(
        self,
        pixel_values: TFModelInputType | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
        outputs = self.vit(
            pixel_values=pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
            training=training,
        )

        return outputs

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "vit", None) is not None:
            with tf.name_scope(self.vit.name):
                self.vit.build(None)
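

# Example usage (an illustrative sketch following the standard `transformers` image
# pipeline; the COCO image URL is just a convenient public test picture):
#
#     from PIL import Image
#     import requests
#     from transformers import AutoImageProcessor, TFViTModel
#
#     url = "http://images.cocodataset.org/val2017/000000039769.jpg"
#     image = Image.open(requests.get(url, stream=True).raw)
#
#     processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
#     model = TFViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
#
#     inputs = processor(images=image, return_tensors="tf")
#     outputs = model(**inputs)
#     outputs.last_hidden_state.shape  # (1, 197, 768), i.e. _EXPECTED_OUTPUT_SHAPE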


class TFViTPooler(keras.layers.Layer):
    def __init__(self, config: ViTConfig, **kwargs):
        super().__init__(**kwargs)

        self.dense = keras.layers.Dense(
            units=config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            activation="tanh",
            name="dense",
        )
        self.config = config

    def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(inputs=first_token_tensor)

        return pooled_output

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "dense", None) is not None:
            with tf.name_scope(self.dense.name):
                self.dense.build([None, None, self.config.hidden_size])


@add_start_docstrings(
    """
    ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
    the [CLS] token) e.g. for ImageNet.

    <Tip>

    Note that it's possible to fine-tune ViT on higher resolution images than the ones it has been trained on, by
    setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
    position embeddings to the higher resolution.

    </Tip>
    """,
    VIT_START_DOCSTRING,
)
class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassificationLoss):
    def __init__(self, config: ViTConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.num_labels = config.num_labels
        self.vit = TFViTMainLayer(config, add_pooling_layer=False, name="vit")

        # Classifier head
        self.classifier = keras.layers.Dense(
            units=config.num_labels,
            kernel_initializer=get_initializer(config.initializer_range),
            name="classifier",
        )
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=TFSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def call(
        self,
        pixel_values: TFModelInputType | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: np.ndarray | tf.Tensor | None = None,
        training: Optional[bool] = False,
    ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
  706. r"""
  707. labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
  708. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  709. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  710. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  711. """
        outputs = self.vit(
            pixel_values=pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
            training=training,
        )
        sequence_output = outputs[0]
        logits = self.classifier(inputs=sequence_output[:, 0, :])
        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "vit", None) is not None:
            with tf.name_scope(self.vit.name):
                self.vit.build(None)
        if getattr(self, "classifier", None) is not None:
            with tf.name_scope(self.classifier.name):
                self.classifier.build([None, None, self.config.hidden_size])
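

# Example usage (an illustrative sketch in the same spirit as _IMAGE_CLASS_CHECKPOINT and
# _IMAGE_CLASS_EXPECTED_OUTPUT above):
#
#     import tensorflow as tf
#     from PIL import Image
#     import requests
#     from transformers import AutoImageProcessor, TFViTForImageClassification
#
#     url = "http://images.cocodataset.org/val2017/000000039769.jpg"
#     image = Image.open(requests.get(url, stream=True).raw)
#
#     processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
#     model = TFViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
#
#     inputs = processor(images=image, return_tensors="tf")
#     logits = model(**inputs).logits
#     predicted_class = int(tf.argmax(logits, axis=-1)[0])
#     print(model.config.id2label[predicted_class])  # e.g. "Egyptian cat"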