modeling_tf_regnet.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608
  1. # coding=utf-8
  2. # Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """TensorFlow RegNet model."""
  16. from typing import Optional, Tuple, Union
  17. import tensorflow as tf
  18. from ...activations_tf import ACT2FN
  19. from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
  20. from ...modeling_tf_outputs import (
  21. TFBaseModelOutputWithNoAttention,
  22. TFBaseModelOutputWithPoolingAndNoAttention,
  23. TFSequenceClassifierOutput,
  24. )
  25. from ...modeling_tf_utils import (
  26. TFPreTrainedModel,
  27. TFSequenceClassificationLoss,
  28. keras,
  29. keras_serializable,
  30. unpack_inputs,
  31. )
  32. from ...tf_utils import shape_list
  33. from ...utils import logging
  34. from .configuration_regnet import RegNetConfig
# Module-level logger for this modeling file.
logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "RegNetConfig"

# Base docstring
# Default checkpoint used in the code-sample docstrings of the base model.
_CHECKPOINT_FOR_DOC = "facebook/regnet-y-040"
# Expected `last_hidden_state` shape for the doc example: (batch, channels, height, width).
_EXPECTED_OUTPUT_SHAPE = [1, 1088, 7, 7]

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "facebook/regnet-y-040"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
  44. class TFRegNetConvLayer(keras.layers.Layer):
  45. def __init__(
  46. self,
  47. in_channels: int,
  48. out_channels: int,
  49. kernel_size: int = 3,
  50. stride: int = 1,
  51. groups: int = 1,
  52. activation: Optional[str] = "relu",
  53. **kwargs,
  54. ):
  55. super().__init__(**kwargs)
  56. # The padding and conv has been verified in
  57. # https://colab.research.google.com/gist/sayakpaul/854bc10eeaf21c9ee2119e0b9f3841a7/scratchpad.ipynb
  58. self.padding = keras.layers.ZeroPadding2D(padding=kernel_size // 2)
  59. self.convolution = keras.layers.Conv2D(
  60. filters=out_channels,
  61. kernel_size=kernel_size,
  62. strides=stride,
  63. padding="VALID",
  64. groups=groups,
  65. use_bias=False,
  66. name="convolution",
  67. )
  68. self.normalization = keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="normalization")
  69. self.activation = ACT2FN[activation] if activation is not None else tf.identity
  70. self.in_channels = in_channels
  71. self.out_channels = out_channels
  72. def call(self, hidden_state):
  73. hidden_state = self.convolution(self.padding(hidden_state))
  74. hidden_state = self.normalization(hidden_state)
  75. hidden_state = self.activation(hidden_state)
  76. return hidden_state
  77. def build(self, input_shape=None):
  78. if self.built:
  79. return
  80. self.built = True
  81. if getattr(self, "convolution", None) is not None:
  82. with tf.name_scope(self.convolution.name):
  83. self.convolution.build([None, None, None, self.in_channels])
  84. if getattr(self, "normalization", None) is not None:
  85. with tf.name_scope(self.normalization.name):
  86. self.normalization.build([None, None, None, self.out_channels])
  87. class TFRegNetEmbeddings(keras.layers.Layer):
  88. """
  89. RegNet Embeddings (stem) composed of a single aggressive convolution.
  90. """
  91. def __init__(self, config: RegNetConfig, **kwargs):
  92. super().__init__(**kwargs)
  93. self.num_channels = config.num_channels
  94. self.embedder = TFRegNetConvLayer(
  95. in_channels=config.num_channels,
  96. out_channels=config.embedding_size,
  97. kernel_size=3,
  98. stride=2,
  99. activation=config.hidden_act,
  100. name="embedder",
  101. )
  102. def call(self, pixel_values):
  103. num_channels = shape_list(pixel_values)[1]
  104. if tf.executing_eagerly() and num_channels != self.num_channels:
  105. raise ValueError(
  106. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  107. )
  108. # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format.
  109. # So change the input format from `NCHW` to `NHWC`.
  110. # shape = (batch_size, in_height, in_width, in_channels=num_channels)
  111. pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
  112. hidden_state = self.embedder(pixel_values)
  113. return hidden_state
  114. def build(self, input_shape=None):
  115. if self.built:
  116. return
  117. self.built = True
  118. if getattr(self, "embedder", None) is not None:
  119. with tf.name_scope(self.embedder.name):
  120. self.embedder.build(None)
  121. class TFRegNetShortCut(keras.layers.Layer):
  122. """
  123. RegNet shortcut, used to project the residual features to the correct size. If needed, it is also used to
  124. downsample the input using `stride=2`.
  125. """
  126. def __init__(self, in_channels: int, out_channels: int, stride: int = 2, **kwargs):
  127. super().__init__(**kwargs)
  128. self.convolution = keras.layers.Conv2D(
  129. filters=out_channels, kernel_size=1, strides=stride, use_bias=False, name="convolution"
  130. )
  131. self.normalization = keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9, name="normalization")
  132. self.in_channels = in_channels
  133. self.out_channels = out_channels
  134. def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
  135. return self.normalization(self.convolution(inputs), training=training)
  136. def build(self, input_shape=None):
  137. if self.built:
  138. return
  139. self.built = True
  140. if getattr(self, "convolution", None) is not None:
  141. with tf.name_scope(self.convolution.name):
  142. self.convolution.build([None, None, None, self.in_channels])
  143. if getattr(self, "normalization", None) is not None:
  144. with tf.name_scope(self.normalization.name):
  145. self.normalization.build([None, None, None, self.out_channels])
  146. class TFRegNetSELayer(keras.layers.Layer):
  147. """
  148. Squeeze and Excitation layer (SE) proposed in [Squeeze-and-Excitation Networks](https://arxiv.org/abs/1709.01507).
  149. """
  150. def __init__(self, in_channels: int, reduced_channels: int, **kwargs):
  151. super().__init__(**kwargs)
  152. self.pooler = keras.layers.GlobalAveragePooling2D(keepdims=True, name="pooler")
  153. self.attention = [
  154. keras.layers.Conv2D(filters=reduced_channels, kernel_size=1, activation="relu", name="attention.0"),
  155. keras.layers.Conv2D(filters=in_channels, kernel_size=1, activation="sigmoid", name="attention.2"),
  156. ]
  157. self.in_channels = in_channels
  158. self.reduced_channels = reduced_channels
  159. def call(self, hidden_state):
  160. # [batch_size, h, w, num_channels] -> [batch_size, 1, 1, num_channels]
  161. pooled = self.pooler(hidden_state)
  162. for layer_module in self.attention:
  163. pooled = layer_module(pooled)
  164. hidden_state = hidden_state * pooled
  165. return hidden_state
  166. def build(self, input_shape=None):
  167. if self.built:
  168. return
  169. self.built = True
  170. if getattr(self, "pooler", None) is not None:
  171. with tf.name_scope(self.pooler.name):
  172. self.pooler.build((None, None, None, None))
  173. if getattr(self, "attention", None) is not None:
  174. with tf.name_scope(self.attention[0].name):
  175. self.attention[0].build([None, None, None, self.in_channels])
  176. with tf.name_scope(self.attention[1].name):
  177. self.attention[1].build([None, None, None, self.reduced_channels])
  178. class TFRegNetXLayer(keras.layers.Layer):
  179. """
  180. RegNet's layer composed by three `3x3` convolutions, same as a ResNet bottleneck layer with reduction = 1.
  181. """
  182. def __init__(self, config: RegNetConfig, in_channels: int, out_channels: int, stride: int = 1, **kwargs):
  183. super().__init__(**kwargs)
  184. should_apply_shortcut = in_channels != out_channels or stride != 1
  185. groups = max(1, out_channels // config.groups_width)
  186. self.shortcut = (
  187. TFRegNetShortCut(in_channels, out_channels, stride=stride, name="shortcut")
  188. if should_apply_shortcut
  189. else keras.layers.Activation("linear", name="shortcut")
  190. )
  191. # `self.layers` instead of `self.layer` because that is a reserved argument.
  192. self.layers = [
  193. TFRegNetConvLayer(in_channels, out_channels, kernel_size=1, activation=config.hidden_act, name="layer.0"),
  194. TFRegNetConvLayer(
  195. out_channels, out_channels, stride=stride, groups=groups, activation=config.hidden_act, name="layer.1"
  196. ),
  197. TFRegNetConvLayer(out_channels, out_channels, kernel_size=1, activation=None, name="layer.2"),
  198. ]
  199. self.activation = ACT2FN[config.hidden_act]
  200. def call(self, hidden_state):
  201. residual = hidden_state
  202. for layer_module in self.layers:
  203. hidden_state = layer_module(hidden_state)
  204. residual = self.shortcut(residual)
  205. hidden_state += residual
  206. hidden_state = self.activation(hidden_state)
  207. return hidden_state
  208. def build(self, input_shape=None):
  209. if self.built:
  210. return
  211. self.built = True
  212. if getattr(self, "shortcut", None) is not None:
  213. with tf.name_scope(self.shortcut.name):
  214. self.shortcut.build(None)
  215. if getattr(self, "layers", None) is not None:
  216. for layer in self.layers:
  217. with tf.name_scope(layer.name):
  218. layer.build(None)
  219. class TFRegNetYLayer(keras.layers.Layer):
  220. """
  221. RegNet's Y layer: an X layer with Squeeze and Excitation.
  222. """
  223. def __init__(self, config: RegNetConfig, in_channels: int, out_channels: int, stride: int = 1, **kwargs):
  224. super().__init__(**kwargs)
  225. should_apply_shortcut = in_channels != out_channels or stride != 1
  226. groups = max(1, out_channels // config.groups_width)
  227. self.shortcut = (
  228. TFRegNetShortCut(in_channels, out_channels, stride=stride, name="shortcut")
  229. if should_apply_shortcut
  230. else keras.layers.Activation("linear", name="shortcut")
  231. )
  232. self.layers = [
  233. TFRegNetConvLayer(in_channels, out_channels, kernel_size=1, activation=config.hidden_act, name="layer.0"),
  234. TFRegNetConvLayer(
  235. out_channels, out_channels, stride=stride, groups=groups, activation=config.hidden_act, name="layer.1"
  236. ),
  237. TFRegNetSELayer(out_channels, reduced_channels=int(round(in_channels / 4)), name="layer.2"),
  238. TFRegNetConvLayer(out_channels, out_channels, kernel_size=1, activation=None, name="layer.3"),
  239. ]
  240. self.activation = ACT2FN[config.hidden_act]
  241. def call(self, hidden_state):
  242. residual = hidden_state
  243. for layer_module in self.layers:
  244. hidden_state = layer_module(hidden_state)
  245. residual = self.shortcut(residual)
  246. hidden_state += residual
  247. hidden_state = self.activation(hidden_state)
  248. return hidden_state
  249. def build(self, input_shape=None):
  250. if self.built:
  251. return
  252. self.built = True
  253. if getattr(self, "shortcut", None) is not None:
  254. with tf.name_scope(self.shortcut.name):
  255. self.shortcut.build(None)
  256. if getattr(self, "layers", None) is not None:
  257. for layer in self.layers:
  258. with tf.name_scope(layer.name):
  259. layer.build(None)
  260. class TFRegNetStage(keras.layers.Layer):
  261. """
  262. A RegNet stage composed by stacked layers.
  263. """
  264. def __init__(
  265. self, config: RegNetConfig, in_channels: int, out_channels: int, stride: int = 2, depth: int = 2, **kwargs
  266. ):
  267. super().__init__(**kwargs)
  268. layer = TFRegNetXLayer if config.layer_type == "x" else TFRegNetYLayer
  269. self.layers = [
  270. # downsampling is done in the first layer with stride of 2
  271. layer(config, in_channels, out_channels, stride=stride, name="layers.0"),
  272. *[layer(config, out_channels, out_channels, name=f"layers.{i+1}") for i in range(depth - 1)],
  273. ]
  274. def call(self, hidden_state):
  275. for layer_module in self.layers:
  276. hidden_state = layer_module(hidden_state)
  277. return hidden_state
  278. def build(self, input_shape=None):
  279. if self.built:
  280. return
  281. self.built = True
  282. if getattr(self, "layers", None) is not None:
  283. for layer in self.layers:
  284. with tf.name_scope(layer.name):
  285. layer.build(None)
  286. class TFRegNetEncoder(keras.layers.Layer):
  287. def __init__(self, config: RegNetConfig, **kwargs):
  288. super().__init__(**kwargs)
  289. self.stages = []
  290. # based on `downsample_in_first_stage`, the first layer of the first stage may or may not downsample the input
  291. self.stages.append(
  292. TFRegNetStage(
  293. config,
  294. config.embedding_size,
  295. config.hidden_sizes[0],
  296. stride=2 if config.downsample_in_first_stage else 1,
  297. depth=config.depths[0],
  298. name="stages.0",
  299. )
  300. )
  301. in_out_channels = zip(config.hidden_sizes, config.hidden_sizes[1:])
  302. for i, ((in_channels, out_channels), depth) in enumerate(zip(in_out_channels, config.depths[1:])):
  303. self.stages.append(TFRegNetStage(config, in_channels, out_channels, depth=depth, name=f"stages.{i+1}"))
  304. def call(
  305. self, hidden_state: tf.Tensor, output_hidden_states: bool = False, return_dict: bool = True
  306. ) -> TFBaseModelOutputWithNoAttention:
  307. hidden_states = () if output_hidden_states else None
  308. for stage_module in self.stages:
  309. if output_hidden_states:
  310. hidden_states = hidden_states + (hidden_state,)
  311. hidden_state = stage_module(hidden_state)
  312. if output_hidden_states:
  313. hidden_states = hidden_states + (hidden_state,)
  314. if not return_dict:
  315. return tuple(v for v in [hidden_state, hidden_states] if v is not None)
  316. return TFBaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=hidden_states)
  317. def build(self, input_shape=None):
  318. if self.built:
  319. return
  320. self.built = True
  321. for stage in self.stages:
  322. with tf.name_scope(stage.name):
  323. stage.build(None)
  324. @keras_serializable
  325. class TFRegNetMainLayer(keras.layers.Layer):
  326. config_class = RegNetConfig
  327. def __init__(self, config, **kwargs):
  328. super().__init__(**kwargs)
  329. self.config = config
  330. self.embedder = TFRegNetEmbeddings(config, name="embedder")
  331. self.encoder = TFRegNetEncoder(config, name="encoder")
  332. self.pooler = keras.layers.GlobalAveragePooling2D(keepdims=True, name="pooler")
  333. @unpack_inputs
  334. def call(
  335. self,
  336. pixel_values: tf.Tensor,
  337. output_hidden_states: Optional[bool] = None,
  338. return_dict: Optional[bool] = None,
  339. training: bool = False,
  340. ) -> TFBaseModelOutputWithPoolingAndNoAttention:
  341. output_hidden_states = (
  342. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  343. )
  344. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  345. embedding_output = self.embedder(pixel_values, training=training)
  346. encoder_outputs = self.encoder(
  347. embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training
  348. )
  349. last_hidden_state = encoder_outputs[0]
  350. pooled_output = self.pooler(last_hidden_state)
  351. # Change to NCHW output format have uniformity in the modules
  352. pooled_output = tf.transpose(pooled_output, perm=(0, 3, 1, 2))
  353. last_hidden_state = tf.transpose(last_hidden_state, perm=(0, 3, 1, 2))
  354. # Change the other hidden state outputs to NCHW as well
  355. if output_hidden_states:
  356. hidden_states = tuple([tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1]])
  357. if not return_dict:
  358. return (last_hidden_state, pooled_output) + encoder_outputs[1:]
  359. return TFBaseModelOutputWithPoolingAndNoAttention(
  360. last_hidden_state=last_hidden_state,
  361. pooler_output=pooled_output,
  362. hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states,
  363. )
  364. def build(self, input_shape=None):
  365. if self.built:
  366. return
  367. self.built = True
  368. if getattr(self, "embedder", None) is not None:
  369. with tf.name_scope(self.embedder.name):
  370. self.embedder.build(None)
  371. if getattr(self, "encoder", None) is not None:
  372. with tf.name_scope(self.encoder.name):
  373. self.encoder.build(None)
  374. if getattr(self, "pooler", None) is not None:
  375. with tf.name_scope(self.pooler.name):
  376. self.pooler.build((None, None, None, None))
  377. class TFRegNetPreTrainedModel(TFPreTrainedModel):
  378. """
  379. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  380. models.
  381. """
  382. config_class = RegNetConfig
  383. base_model_prefix = "regnet"
  384. main_input_name = "pixel_values"
  385. @property
  386. def input_signature(self):
  387. return {"pixel_values": tf.TensorSpec(shape=(None, self.config.num_channels, 224, 224), dtype=tf.float32)}
  388. REGNET_START_DOCSTRING = r"""
  389. This model is a Tensorflow
  390. [keras.layers.Layer](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer) sub-class. Use it as a
  391. regular Tensorflow Module and refer to the Tensorflow documentation for all matter related to general usage and
  392. behavior.
  393. Parameters:
  394. config ([`RegNetConfig`]): Model configuration class with all the parameters of the model.
  395. Initializing with a config file does not load the weights associated with the model, only the
  396. configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
  397. """
  398. REGNET_INPUTS_DOCSTRING = r"""
  399. Args:
  400. pixel_values (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
  401. Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
  402. [`ConveNextImageProcessor.__call__`] for details.
  403. output_hidden_states (`bool`, *optional*):
  404. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  405. more detail.
  406. return_dict (`bool`, *optional*):
  407. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  408. """
  409. @add_start_docstrings(
  410. "The bare RegNet model outputting raw features without any specific head on top.",
  411. REGNET_START_DOCSTRING,
  412. )
  413. class TFRegNetModel(TFRegNetPreTrainedModel):
  414. def __init__(self, config: RegNetConfig, *inputs, **kwargs):
  415. super().__init__(config, *inputs, **kwargs)
  416. self.regnet = TFRegNetMainLayer(config, name="regnet")
  417. @unpack_inputs
  418. @add_start_docstrings_to_model_forward(REGNET_INPUTS_DOCSTRING)
  419. @add_code_sample_docstrings(
  420. checkpoint=_CHECKPOINT_FOR_DOC,
  421. output_type=TFBaseModelOutputWithPoolingAndNoAttention,
  422. config_class=_CONFIG_FOR_DOC,
  423. modality="vision",
  424. expected_output=_EXPECTED_OUTPUT_SHAPE,
  425. )
  426. def call(
  427. self,
  428. pixel_values: tf.Tensor,
  429. output_hidden_states: Optional[bool] = None,
  430. return_dict: Optional[bool] = None,
  431. training: bool = False,
  432. ) -> Union[TFBaseModelOutputWithPoolingAndNoAttention, Tuple[tf.Tensor]]:
  433. output_hidden_states = (
  434. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  435. )
  436. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  437. outputs = self.regnet(
  438. pixel_values=pixel_values,
  439. output_hidden_states=output_hidden_states,
  440. return_dict=return_dict,
  441. training=training,
  442. )
  443. if not return_dict:
  444. return (outputs[0],) + outputs[1:]
  445. return TFBaseModelOutputWithPoolingAndNoAttention(
  446. last_hidden_state=outputs.last_hidden_state,
  447. pooler_output=outputs.pooler_output,
  448. hidden_states=outputs.hidden_states,
  449. )
  450. def build(self, input_shape=None):
  451. if self.built:
  452. return
  453. self.built = True
  454. if getattr(self, "regnet", None) is not None:
  455. with tf.name_scope(self.regnet.name):
  456. self.regnet.build(None)
  457. @add_start_docstrings(
  458. """
  459. RegNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
  460. ImageNet.
  461. """,
  462. REGNET_START_DOCSTRING,
  463. )
  464. class TFRegNetForImageClassification(TFRegNetPreTrainedModel, TFSequenceClassificationLoss):
  465. def __init__(self, config: RegNetConfig, *inputs, **kwargs):
  466. super().__init__(config, *inputs, **kwargs)
  467. self.num_labels = config.num_labels
  468. self.regnet = TFRegNetMainLayer(config, name="regnet")
  469. # classification head
  470. self.classifier = [
  471. keras.layers.Flatten(),
  472. keras.layers.Dense(config.num_labels, name="classifier.1") if config.num_labels > 0 else tf.identity,
  473. ]
  474. @unpack_inputs
  475. @add_start_docstrings_to_model_forward(REGNET_INPUTS_DOCSTRING)
  476. @add_code_sample_docstrings(
  477. checkpoint=_IMAGE_CLASS_CHECKPOINT,
  478. output_type=TFSequenceClassifierOutput,
  479. config_class=_CONFIG_FOR_DOC,
  480. expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
  481. )
  482. def call(
  483. self,
  484. pixel_values: Optional[tf.Tensor] = None,
  485. labels: Optional[tf.Tensor] = None,
  486. output_hidden_states: Optional[bool] = None,
  487. return_dict: Optional[bool] = None,
  488. training: bool = False,
  489. ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
  490. r"""
  491. labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
  492. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  493. config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  494. """
  495. output_hidden_states = (
  496. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  497. )
  498. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  499. outputs = self.regnet(
  500. pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training
  501. )
  502. pooled_output = outputs.pooler_output if return_dict else outputs[1]
  503. flattened_output = self.classifier[0](pooled_output)
  504. logits = self.classifier[1](flattened_output)
  505. loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)
  506. if not return_dict:
  507. output = (logits,) + outputs[2:]
  508. return ((loss,) + output) if loss is not None else output
  509. return TFSequenceClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
  510. def build(self, input_shape=None):
  511. if self.built:
  512. return
  513. self.built = True
  514. if getattr(self, "regnet", None) is not None:
  515. with tf.name_scope(self.regnet.name):
  516. self.regnet.build(None)
  517. if getattr(self, "classifier", None) is not None:
  518. with tf.name_scope(self.classifier[1].name):
  519. self.classifier[1].build([None, None, None, self.config.hidden_sizes[-1]])