modeling_tf_convnextv2.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680
  1. # coding=utf-8
  2. # Copyright 2023 Meta Platforms Inc. and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """TF 2.0 ConvNextV2 model."""
  16. from __future__ import annotations
  17. from typing import List, Optional, Tuple, Union
  18. import numpy as np
  19. import tensorflow as tf
  20. from ...activations_tf import get_tf_activation
  21. from ...modeling_tf_outputs import (
  22. TFBaseModelOutputWithNoAttention,
  23. TFBaseModelOutputWithPooling,
  24. TFBaseModelOutputWithPoolingAndNoAttention,
  25. TFImageClassifierOutputWithNoAttention,
  26. )
  27. from ...modeling_tf_utils import (
  28. TFModelInputType,
  29. TFPreTrainedModel,
  30. TFSequenceClassificationLoss,
  31. get_initializer,
  32. keras,
  33. keras_serializable,
  34. unpack_inputs,
  35. )
  36. from ...tf_utils import shape_list
  37. from ...utils import (
  38. add_code_sample_docstrings,
  39. add_start_docstrings,
  40. add_start_docstrings_to_model_forward,
  41. logging,
  42. )
  43. from .configuration_convnextv2 import ConvNextV2Config
# Module-level logger obtained from the library's logging helper.
logger = logging.get_logger(__name__)

# General docstring
# Passed as `config_class` to `add_code_sample_docstrings` on both public models in this file.
_CONFIG_FOR_DOC = "ConvNextV2Config"

# Base docstring
# Checkpoint and expected output passed to `add_code_sample_docstrings` on `TFConvNextV2Model.call`.
_CHECKPOINT_FOR_DOC = "facebook/convnextv2-tiny-1k-224"
_EXPECTED_OUTPUT_SHAPE = [1, 768, 7, 7]

# Image classification docstring
# Checkpoint and expected output passed to `add_code_sample_docstrings` on
# `TFConvNextV2ForImageClassification.call`.
_IMAGE_CLASS_CHECKPOINT = "facebook/convnextv2-tiny-1k-224"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
  53. # Copied from transformers.models.convnext.modeling_tf_convnext.TFConvNextDropPath with ConvNext->ConvNextV2
  54. class TFConvNextV2DropPath(keras.layers.Layer):
  55. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  56. References:
  57. (1) github.com:rwightman/pytorch-image-models
  58. """
  59. def __init__(self, drop_path: float, **kwargs):
  60. super().__init__(**kwargs)
  61. self.drop_path = drop_path
  62. def call(self, x: tf.Tensor, training=None):
  63. if training:
  64. keep_prob = 1 - self.drop_path
  65. shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)
  66. random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
  67. random_tensor = tf.floor(random_tensor)
  68. return (x / keep_prob) * random_tensor
  69. return x
  70. class TFConvNextV2GRN(keras.layers.Layer):
  71. """GRN (Global Response Normalization) layer"""
  72. def __init__(self, config: ConvNextV2Config, dim: int, **kwargs):
  73. super().__init__(**kwargs)
  74. self.dim = dim
  75. def build(self, input_shape: tf.TensorShape = None):
  76. # PT's `nn.Parameters` must be mapped to a TF layer weight to inherit the same name hierarchy (and vice-versa)
  77. self.weight = self.add_weight(
  78. name="weight",
  79. shape=(1, 1, 1, self.dim),
  80. initializer=keras.initializers.Zeros(),
  81. )
  82. self.bias = self.add_weight(
  83. name="bias",
  84. shape=(1, 1, 1, self.dim),
  85. initializer=keras.initializers.Zeros(),
  86. )
  87. return super().build(input_shape)
  88. def call(self, hidden_states: tf.Tensor):
  89. global_features = tf.norm(hidden_states, ord="euclidean", axis=(1, 2), keepdims=True)
  90. norm_features = global_features / (tf.reduce_mean(global_features, axis=-1, keepdims=True) + 1e-6)
  91. hidden_states = self.weight * (hidden_states * norm_features) + self.bias + hidden_states
  92. return hidden_states
  93. # Copied from transformers.models.convnext.modeling_tf_convnext.TFConvNextEmbeddings with ConvNext->ConvNextV2
  94. class TFConvNextV2Embeddings(keras.layers.Layer):
  95. """This class is comparable to (and inspired by) the SwinEmbeddings class
  96. found in src/transformers/models/swin/modeling_swin.py.
  97. """
  98. def __init__(self, config: ConvNextV2Config, **kwargs):
  99. super().__init__(**kwargs)
  100. self.patch_embeddings = keras.layers.Conv2D(
  101. filters=config.hidden_sizes[0],
  102. kernel_size=config.patch_size,
  103. strides=config.patch_size,
  104. name="patch_embeddings",
  105. kernel_initializer=get_initializer(config.initializer_range),
  106. bias_initializer=keras.initializers.Zeros(),
  107. )
  108. self.layernorm = keras.layers.LayerNormalization(epsilon=1e-6, name="layernorm")
  109. self.num_channels = config.num_channels
  110. self.config = config
  111. def call(self, pixel_values):
  112. if isinstance(pixel_values, dict):
  113. pixel_values = pixel_values["pixel_values"]
  114. tf.debugging.assert_equal(
  115. shape_list(pixel_values)[1],
  116. self.num_channels,
  117. message="Make sure that the channel dimension of the pixel values match with the one set in the configuration.",
  118. )
  119. # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format.
  120. # So change the input format from `NCHW` to `NHWC`.
  121. # shape = (batch_size, in_height, in_width, in_channels)
  122. pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
  123. embeddings = self.patch_embeddings(pixel_values)
  124. embeddings = self.layernorm(embeddings)
  125. return embeddings
  126. def build(self, input_shape=None):
  127. if self.built:
  128. return
  129. self.built = True
  130. if getattr(self, "patch_embeddings", None) is not None:
  131. with tf.name_scope(self.patch_embeddings.name):
  132. self.patch_embeddings.build([None, None, None, self.config.num_channels])
  133. if getattr(self, "layernorm", None) is not None:
  134. with tf.name_scope(self.layernorm.name):
  135. self.layernorm.build([None, None, None, self.config.hidden_sizes[0]])
  136. class TFConvNextV2Layer(keras.layers.Layer):
  137. """This corresponds to the `Block` class in the original implementation.
  138. There are two equivalent implementations: [DwConv, LayerNorm (channels_first), Conv, GELU,1x1 Conv]; all in (N, C,
  139. H, W) (2) [DwConv, Permute to (N, H, W, C), LayerNorm (channels_last), Linear, GELU, Linear]; Permute back
  140. The authors used (2) as they find it slightly faster in PyTorch. Since we already permuted the inputs to follow
  141. NHWC ordering, we can just apply the operations straight-away without the permutation.
  142. Args:
  143. config (`ConvNextV2Config`):
  144. Model configuration class.
  145. dim (`int`):
  146. Number of input channels.
  147. drop_path (`float`, *optional*, defaults to 0.0):
  148. Stochastic depth rate.
  149. """
  150. def __init__(self, config: ConvNextV2Config, dim: int, drop_path: float = 0.0, **kwargs):
  151. super().__init__(**kwargs)
  152. self.dim = dim
  153. self.config = config
  154. self.dwconv = keras.layers.Conv2D(
  155. filters=dim,
  156. kernel_size=7,
  157. padding="same",
  158. groups=dim,
  159. kernel_initializer=get_initializer(config.initializer_range),
  160. bias_initializer=keras.initializers.Zeros(),
  161. name="dwconv",
  162. ) # depthwise conv
  163. self.layernorm = keras.layers.LayerNormalization(
  164. epsilon=1e-6,
  165. name="layernorm",
  166. )
  167. self.pwconv1 = keras.layers.Dense(
  168. units=4 * dim,
  169. kernel_initializer=get_initializer(config.initializer_range),
  170. bias_initializer=keras.initializers.Zeros(),
  171. name="pwconv1",
  172. ) # pointwise/1x1 convs, implemented with linear layers
  173. self.act = get_tf_activation(config.hidden_act)
  174. self.grn = TFConvNextV2GRN(config, 4 * dim, dtype=tf.float32, name="grn")
  175. self.pwconv2 = keras.layers.Dense(
  176. units=dim,
  177. kernel_initializer=get_initializer(config.initializer_range),
  178. bias_initializer=keras.initializers.Zeros(),
  179. name="pwconv2",
  180. )
  181. # Using `layers.Activation` instead of `tf.identity` to better control `training`
  182. # behaviour.
  183. self.drop_path = (
  184. TFConvNextV2DropPath(drop_path, name="drop_path")
  185. if drop_path > 0.0
  186. else keras.layers.Activation("linear", name="drop_path")
  187. )
  188. def call(self, hidden_states, training=False):
  189. input = hidden_states
  190. x = self.dwconv(hidden_states)
  191. x = self.layernorm(x)
  192. x = self.pwconv1(x)
  193. x = self.act(x)
  194. x = self.grn(x)
  195. x = self.pwconv2(x)
  196. x = self.drop_path(x, training=training)
  197. x = input + x
  198. return x
  199. def build(self, input_shape=None):
  200. if self.built:
  201. return
  202. self.built = True
  203. if getattr(self, "dwconv", None) is not None:
  204. with tf.name_scope(self.dwconv.name):
  205. self.dwconv.build([None, None, None, self.dim])
  206. if getattr(self, "layernorm", None) is not None:
  207. with tf.name_scope(self.layernorm.name):
  208. self.layernorm.build([None, None, None, self.dim])
  209. if getattr(self, "pwconv1", None) is not None:
  210. with tf.name_scope(self.pwconv1.name):
  211. self.pwconv1.build([None, None, self.dim])
  212. if getattr(self, "grn", None) is not None:
  213. with tf.name_scope(self.grn.name):
  214. self.grn.build(None)
  215. if getattr(self, "pwconv2", None) is not None:
  216. with tf.name_scope(self.pwconv2.name):
  217. self.pwconv2.build([None, None, 4 * self.dim])
  218. if getattr(self, "drop_path", None) is not None:
  219. with tf.name_scope(self.drop_path.name):
  220. self.drop_path.build(None)
  221. # Copied from transformers.models.convnext.modeling_tf_convnext.TFConvNextStage with ConvNext->ConvNextV2
  222. class TFConvNextV2Stage(keras.layers.Layer):
  223. """ConvNextV2 stage, consisting of an optional downsampling layer + multiple residual blocks.
  224. Args:
  225. config (`ConvNextV2V2Config`):
  226. Model configuration class.
  227. in_channels (`int`):
  228. Number of input channels.
  229. out_channels (`int`):
  230. Number of output channels.
  231. depth (`int`):
  232. Number of residual blocks.
  233. drop_path_rates(`List[float]`):
  234. Stochastic depth rates for each layer.
  235. """
  236. def __init__(
  237. self,
  238. config: ConvNextV2Config,
  239. in_channels: int,
  240. out_channels: int,
  241. kernel_size: int = 2,
  242. stride: int = 2,
  243. depth: int = 2,
  244. drop_path_rates: Optional[List[float]] = None,
  245. **kwargs,
  246. ):
  247. super().__init__(**kwargs)
  248. if in_channels != out_channels or stride > 1:
  249. self.downsampling_layer = [
  250. keras.layers.LayerNormalization(
  251. epsilon=1e-6,
  252. name="downsampling_layer.0",
  253. ),
  254. # Inputs to this layer will follow NHWC format since we
  255. # transposed the inputs from NCHW to NHWC in the `TFConvNextV2Embeddings`
  256. # layer. All the outputs throughout the model will be in NHWC
  257. # from this point on until the output where we again change to
  258. # NCHW.
  259. keras.layers.Conv2D(
  260. filters=out_channels,
  261. kernel_size=kernel_size,
  262. strides=stride,
  263. kernel_initializer=get_initializer(config.initializer_range),
  264. bias_initializer=keras.initializers.Zeros(),
  265. name="downsampling_layer.1",
  266. ),
  267. ]
  268. else:
  269. self.downsampling_layer = [tf.identity]
  270. drop_path_rates = drop_path_rates or [0.0] * depth
  271. self.layers = [
  272. TFConvNextV2Layer(
  273. config,
  274. dim=out_channels,
  275. drop_path=drop_path_rates[j],
  276. name=f"layers.{j}",
  277. )
  278. for j in range(depth)
  279. ]
  280. self.in_channels = in_channels
  281. self.out_channels = out_channels
  282. self.stride = stride
  283. def call(self, hidden_states):
  284. for layer in self.downsampling_layer:
  285. hidden_states = layer(hidden_states)
  286. for layer in self.layers:
  287. hidden_states = layer(hidden_states)
  288. return hidden_states
  289. def build(self, input_shape=None):
  290. if self.built:
  291. return
  292. self.built = True
  293. if getattr(self, "layers", None) is not None:
  294. for layer in self.layers:
  295. with tf.name_scope(layer.name):
  296. layer.build(None)
  297. if self.in_channels != self.out_channels or self.stride > 1:
  298. with tf.name_scope(self.downsampling_layer[0].name):
  299. self.downsampling_layer[0].build([None, None, None, self.in_channels])
  300. with tf.name_scope(self.downsampling_layer[1].name):
  301. self.downsampling_layer[1].build([None, None, None, self.in_channels])
  302. class TFConvNextV2Encoder(keras.layers.Layer):
  303. def __init__(self, config: ConvNextV2Config, **kwargs):
  304. super().__init__(**kwargs)
  305. self.stages = []
  306. drop_path_rates = tf.linspace(0.0, config.drop_path_rate, sum(config.depths))
  307. drop_path_rates = tf.split(drop_path_rates, config.depths)
  308. drop_path_rates = [x.numpy().tolist() for x in drop_path_rates]
  309. prev_chs = config.hidden_sizes[0]
  310. for i in range(config.num_stages):
  311. out_chs = config.hidden_sizes[i]
  312. stage = TFConvNextV2Stage(
  313. config,
  314. in_channels=prev_chs,
  315. out_channels=out_chs,
  316. stride=2 if i > 0 else 1,
  317. depth=config.depths[i],
  318. drop_path_rates=drop_path_rates[i],
  319. name=f"stages.{i}",
  320. )
  321. self.stages.append(stage)
  322. prev_chs = out_chs
  323. def call(
  324. self,
  325. hidden_states: tf.Tensor,
  326. output_hidden_states: Optional[bool] = False,
  327. return_dict: Optional[bool] = True,
  328. ) -> Union[Tuple, TFBaseModelOutputWithNoAttention]:
  329. all_hidden_states = () if output_hidden_states else None
  330. for i, layer_module in enumerate(self.stages):
  331. if output_hidden_states:
  332. all_hidden_states = all_hidden_states + (hidden_states,)
  333. hidden_states = layer_module(hidden_states)
  334. if output_hidden_states:
  335. all_hidden_states = all_hidden_states + (hidden_states,)
  336. if not return_dict:
  337. return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
  338. return TFBaseModelOutputWithNoAttention(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
  339. def build(self, input_shape=None):
  340. for stage in self.stages:
  341. with tf.name_scope(stage.name):
  342. stage.build(None)
  343. @keras_serializable
  344. class TFConvNextV2MainLayer(keras.layers.Layer):
  345. config_class = ConvNextV2Config
  346. def __init__(self, config: ConvNextV2Config, **kwargs):
  347. super().__init__(**kwargs)
  348. self.config = config
  349. self.embeddings = TFConvNextV2Embeddings(config, name="embeddings")
  350. self.encoder = TFConvNextV2Encoder(config, name="encoder")
  351. self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
  352. # We are setting the `data_format` like so because from here on we will revert to the
  353. # NCHW output format
  354. self.pooler = keras.layers.GlobalAvgPool2D(data_format="channels_last")
  355. @unpack_inputs
  356. def call(
  357. self,
  358. pixel_values: TFModelInputType | None = None,
  359. output_hidden_states: Optional[bool] = None,
  360. return_dict: Optional[bool] = None,
  361. training: bool = False,
  362. ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
  363. output_hidden_states = (
  364. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  365. )
  366. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  367. if pixel_values is None:
  368. raise ValueError("You have to specify pixel_values")
  369. embedding_output = self.embeddings(pixel_values, training=training)
  370. encoder_outputs = self.encoder(
  371. embedding_output,
  372. output_hidden_states=output_hidden_states,
  373. return_dict=return_dict,
  374. training=training,
  375. )
  376. last_hidden_state = encoder_outputs[0]
  377. # Change to NCHW output format have uniformity in the modules
  378. pooled_output = self.pooler(last_hidden_state)
  379. last_hidden_state = tf.transpose(last_hidden_state, perm=(0, 3, 1, 2))
  380. pooled_output = self.layernorm(pooled_output)
  381. # Change the other hidden state outputs to NCHW as well
  382. if output_hidden_states:
  383. hidden_states = tuple([tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1]])
  384. if not return_dict:
  385. hidden_states = hidden_states if output_hidden_states else ()
  386. return (last_hidden_state, pooled_output) + hidden_states
  387. return TFBaseModelOutputWithPoolingAndNoAttention(
  388. last_hidden_state=last_hidden_state,
  389. pooler_output=pooled_output,
  390. hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states,
  391. )
  392. def build(self, input_shape=None):
  393. if self.built:
  394. return
  395. self.built = True
  396. if getattr(self, "embeddings", None) is not None:
  397. with tf.name_scope(self.embeddings.name):
  398. self.embeddings.build(None)
  399. if getattr(self, "encoder", None) is not None:
  400. with tf.name_scope(self.encoder.name):
  401. self.encoder.build(None)
  402. if getattr(self, "layernorm", None) is not None:
  403. with tf.name_scope(self.layernorm.name):
  404. self.layernorm.build([None, self.config.hidden_sizes[-1]])
class TFConvNextV2PreTrainedModel(TFPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # Configuration class associated with this model family.
    config_class = ConvNextV2Config
    # Matches the attribute name (`self.convnextv2`) under which the concrete models in this file store the main layer.
    base_model_prefix = "convnextv2"
    # Name of the primary input argument of the `call` methods in this file.
    main_input_name = "pixel_values"
# Shared introduction passed to `add_start_docstrings` on the public model classes in this file.
# NOTE(review): the list/dict examples below mention `attention_mask` and `token_type_ids`, which none of the
# `call` signatures in this file accept -- looks like generic boilerplate; confirm before relying on it.
CONVNEXTV2_START_DOCSTRING = r"""
    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
    behavior.

    <Tip>

    TensorFlow models and layers in `transformers` accept two formats as input:

    - having all inputs as keyword arguments (like PyTorch models), or
    - having all inputs as a list, tuple or dict in the first positional argument.

    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
    positional argument:

    - a single Tensor with `pixel_values` only and nothing else: `model(pixel_values)`
    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
    `model([pixel_values, attention_mask])` or `model([pixel_values, attention_mask, token_type_ids])`
    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
    `model({"pixel_values": pixel_values, "token_type_ids": token_type_ids})`

    Note that when creating models and layers with
    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
    about any of this, as you can just pass inputs like you would to any other Python function!

    </Tip>

    Parameters:
        config ([`ConvNextV2Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""
# Argument documentation passed to `add_start_docstrings_to_model_forward` on the `call` methods of the public
# models in this file.
CONVNEXTV2_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`ConvNextImageProcessor.__call__`] for details.

        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
            used instead.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
            eager mode, in graph mode the value will always be set to `True`.
"""
  457. @add_start_docstrings(
  458. "The bare ConvNextV2 model outputting raw features without any specific head on top.",
  459. CONVNEXTV2_START_DOCSTRING,
  460. )
  461. class TFConvNextV2Model(TFConvNextV2PreTrainedModel):
  462. def __init__(self, config: ConvNextV2Config, *inputs, **kwargs):
  463. super().__init__(config, *inputs, **kwargs)
  464. self.convnextv2 = TFConvNextV2MainLayer(config, name="convnextv2")
  465. @unpack_inputs
  466. @add_start_docstrings_to_model_forward(CONVNEXTV2_INPUTS_DOCSTRING)
  467. @add_code_sample_docstrings(
  468. checkpoint=_CHECKPOINT_FOR_DOC,
  469. output_type=TFBaseModelOutputWithPoolingAndNoAttention,
  470. config_class=_CONFIG_FOR_DOC,
  471. modality="vision",
  472. expected_output=_EXPECTED_OUTPUT_SHAPE,
  473. )
  474. def call(
  475. self,
  476. pixel_values: TFModelInputType | None = None,
  477. output_hidden_states: Optional[bool] = None,
  478. return_dict: Optional[bool] = None,
  479. training: bool = False,
  480. ) -> Union[TFBaseModelOutputWithPoolingAndNoAttention, Tuple[tf.Tensor]]:
  481. output_hidden_states = (
  482. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  483. )
  484. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  485. if pixel_values is None:
  486. raise ValueError("You have to specify pixel_values")
  487. outputs = self.convnextv2(
  488. pixel_values=pixel_values,
  489. output_hidden_states=output_hidden_states,
  490. return_dict=return_dict,
  491. training=training,
  492. )
  493. if not return_dict:
  494. return outputs[:]
  495. return TFBaseModelOutputWithPoolingAndNoAttention(
  496. last_hidden_state=outputs.last_hidden_state,
  497. pooler_output=outputs.pooler_output,
  498. hidden_states=outputs.hidden_states,
  499. )
  500. def build(self, input_shape=None):
  501. if self.built:
  502. return
  503. self.built = True
  504. if getattr(self, "convnextv2", None) is not None:
  505. with tf.name_scope(self.convnextv2.name):
  506. self.convnextv2.build(None)
  507. @add_start_docstrings(
  508. """
  509. ConvNextV2 Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
  510. ImageNet.
  511. """,
  512. CONVNEXTV2_START_DOCSTRING,
  513. )
  514. class TFConvNextV2ForImageClassification(TFConvNextV2PreTrainedModel, TFSequenceClassificationLoss):
  515. def __init__(self, config: ConvNextV2Config, *inputs, **kwargs):
  516. super().__init__(config, *inputs, **kwargs)
  517. self.num_labels = config.num_labels
  518. self.convnextv2 = TFConvNextV2MainLayer(config, name="convnextv2")
  519. # Classifier head
  520. self.classifier = keras.layers.Dense(
  521. units=config.num_labels,
  522. kernel_initializer=get_initializer(config.initializer_range),
  523. bias_initializer=keras.initializers.Zeros(),
  524. name="classifier",
  525. )
  526. @unpack_inputs
  527. @add_start_docstrings_to_model_forward(CONVNEXTV2_INPUTS_DOCSTRING)
  528. @add_code_sample_docstrings(
  529. checkpoint=_IMAGE_CLASS_CHECKPOINT,
  530. output_type=TFImageClassifierOutputWithNoAttention,
  531. config_class=_CONFIG_FOR_DOC,
  532. expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
  533. )
  534. def call(
  535. self,
  536. pixel_values: TFModelInputType | None = None,
  537. output_hidden_states: Optional[bool] = None,
  538. return_dict: Optional[bool] = None,
  539. labels: np.ndarray | tf.Tensor | None = None,
  540. training: Optional[bool] = False,
  541. ) -> Union[TFImageClassifierOutputWithNoAttention, Tuple[tf.Tensor]]:
  542. r"""
  543. labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*):
  544. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  545. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  546. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  547. """
  548. output_hidden_states = (
  549. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  550. )
  551. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  552. if pixel_values is None:
  553. raise ValueError("You have to specify pixel_values")
  554. outputs = self.convnextv2(
  555. pixel_values,
  556. output_hidden_states=output_hidden_states,
  557. return_dict=return_dict,
  558. training=training,
  559. )
  560. pooled_output = outputs.pooler_output if return_dict else outputs[1]
  561. logits = self.classifier(pooled_output)
  562. loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
  563. if not return_dict:
  564. output = (logits,) + outputs[2:]
  565. return ((loss,) + output) if loss is not None else output
  566. return TFImageClassifierOutputWithNoAttention(
  567. loss=loss,
  568. logits=logits,
  569. hidden_states=outputs.hidden_states,
  570. )
  571. def build(self, input_shape=None):
  572. if self.built:
  573. return
  574. self.built = True
  575. if getattr(self, "convnextv2", None) is not None:
  576. with tf.name_scope(self.convnextv2.name):
  577. self.convnextv2.build(None)
  578. if getattr(self, "classifier", None) is not None:
  579. with tf.name_scope(self.classifier.name):
  580. self.classifier.build([None, None, self.config.hidden_sizes[-1]])