# coding=utf-8
# Copyright 2023 Google Research, Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch EfficientNet model."""

import math
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_outputs import (
    BaseModelOutputWithNoAttention,
    BaseModelOutputWithPoolingAndNoAttention,
    ImageClassifierOutputWithNoAttention,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
)
from .configuration_efficientnet import EfficientNetConfig


logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "EfficientNetConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "google/efficientnet-b7"
_EXPECTED_OUTPUT_SHAPE = [1, 768, 7, 7]

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "google/efficientnet-b7"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"


EFFICIENTNET_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
    behavior.

    Parameters:
        config ([`EfficientNetConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

EFFICIENTNET_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`AutoImageProcessor.__call__`] for details.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


def round_filters(config: EfficientNetConfig, num_channels: int):
    r"""
    Round the number of filters based on the width multiplier, keeping the result divisible by
    `config.depth_divisor`.
    """
    divisor = config.depth_divisor
    num_channels *= config.width_coefficient
    new_dim = max(divisor, int(num_channels + divisor / 2) // divisor * divisor)

    # Make sure that rounding down does not reduce the channel count by more than 10%.
    if new_dim < 0.9 * num_channels:
        new_dim += divisor

    return int(new_dim)
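

# Worked example (illustrative values, not taken from the library): with `width_coefficient=2.0`
# and `depth_divisor=8`, `round_filters(config, 32)` scales 32 to 64.0, rounds to the nearest
# multiple of 8 via int(64.0 + 4) // 8 * 8 = 64, and since 64 >= 0.9 * 64.0 it returns 64.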


def correct_pad(kernel_size: Union[int, Tuple], adjust: bool = True):
    r"""
    Utility function to get the tuple padding value for the depthwise convolution.

    Args:
        kernel_size (`int` or `tuple`):
            Kernel size of the convolution layers.
        adjust (`bool`, *optional*, defaults to `True`):
            Adjusts the padding value to apply the extra pixel only to the right and bottom sides of the input.
    """
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)

    correct = (kernel_size[0] // 2, kernel_size[1] // 2)

    if adjust:
        return (correct[1] - 1, correct[1], correct[0] - 1, correct[0])
    else:
        return (correct[1], correct[1], correct[0], correct[0])
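

# Worked example (illustrative): for a 3x3 kernel, `correct_pad(3)` returns (0, 1, 0, 1),
# i.e. one extra pixel on the right and bottom only, while `correct_pad(3, adjust=False)`
# returns the symmetric (1, 1, 1, 1). The tuple is in `nn.ZeroPad2d` order:
# (left, right, top, bottom).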


class EfficientNetEmbeddings(nn.Module):
    r"""
    A module that corresponds to the stem module of the original work.
    """

    def __init__(self, config: EfficientNetConfig):
        super().__init__()

        self.out_dim = round_filters(config, 32)
        self.padding = nn.ZeroPad2d(padding=(0, 1, 0, 1))
        self.convolution = nn.Conv2d(
            config.num_channels, self.out_dim, kernel_size=3, stride=2, padding="valid", bias=False
        )
        self.batchnorm = nn.BatchNorm2d(self.out_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum)
        self.activation = ACT2FN[config.hidden_act]

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        features = self.padding(pixel_values)
        features = self.convolution(features)
        features = self.batchnorm(features)
        features = self.activation(features)

        return features
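

# Shape sketch (assuming `width_coefficient=1.0`, so the stem keeps 32 filters): a
# (1, 3, 224, 224) input is zero-padded to (1, 3, 225, 225), and the stride-2 "valid"
# 3x3 convolution then yields (1, 32, 112, 112), reproducing TensorFlow's asymmetric
# "same" padding for even input sizes.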


class EfficientNetDepthwiseConv2d(nn.Conv2d):
    def __init__(
        self,
        in_channels,
        depth_multiplier=1,
        kernel_size=3,
        stride=1,
        padding=0,
        dilation=1,
        bias=True,
        padding_mode="zeros",
    ):
        out_channels = in_channels * depth_multiplier
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=bias,
            padding_mode=padding_mode,
        )
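

# Setting `groups=in_channels` makes this a depthwise convolution: each input channel is
# convolved with its own `depth_multiplier` filters instead of mixing channels. As a rough
# cost comparison, a 3x3 depthwise convolution over C channels has C * 3 * 3 weights,
# versus C * C * 3 * 3 for a regular 3x3 convolution with C output channels.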


class EfficientNetExpansionLayer(nn.Module):
    r"""
    This corresponds to the expansion phase of each block in the original implementation.
    """

    def __init__(self, config: EfficientNetConfig, in_dim: int, out_dim: int, stride: int):
        super().__init__()
        self.expand_conv = nn.Conv2d(
            in_channels=in_dim,
            out_channels=out_dim,
            kernel_size=1,
            padding="same",
            bias=False,
        )
        self.expand_bn = nn.BatchNorm2d(num_features=out_dim, eps=config.batch_norm_eps)
        self.expand_act = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
        # Expand phase
        hidden_states = self.expand_conv(hidden_states)
        hidden_states = self.expand_bn(hidden_states)
        hidden_states = self.expand_act(hidden_states)

        return hidden_states


class EfficientNetDepthwiseLayer(nn.Module):
    r"""
    This corresponds to the depthwise convolution phase of each block in the original implementation.
    """

    def __init__(
        self,
        config: EfficientNetConfig,
        in_dim: int,
        stride: int,
        kernel_size: int,
        adjust_padding: bool,
    ):
        super().__init__()
        self.stride = stride
        conv_pad = "valid" if self.stride == 2 else "same"
        padding = correct_pad(kernel_size, adjust=adjust_padding)

        self.depthwise_conv_pad = nn.ZeroPad2d(padding=padding)
        self.depthwise_conv = EfficientNetDepthwiseConv2d(
            in_dim, kernel_size=kernel_size, stride=stride, padding=conv_pad, bias=False
        )
        self.depthwise_norm = nn.BatchNorm2d(
            num_features=in_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum
        )
        self.depthwise_act = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
        # Depthwise convolution
        if self.stride == 2:
            hidden_states = self.depthwise_conv_pad(hidden_states)

        hidden_states = self.depthwise_conv(hidden_states)
        hidden_states = self.depthwise_norm(hidden_states)
        hidden_states = self.depthwise_act(hidden_states)

        return hidden_states
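

# Padding note (illustrative): strided depthwise convolutions emulate TensorFlow's "same"
# padding by zero-padding explicitly and then convolving with "valid" padding. For a 5x5
# kernel with stride 2, `correct_pad(5)` gives (1, 2, 1, 2), so a (1, C, 112, 112) input is
# padded to (1, C, 115, 115) and reduced to (1, C, 56, 56).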


class EfficientNetSqueezeExciteLayer(nn.Module):
    r"""
    This corresponds to the squeeze-and-excitation phase of each block in the original implementation.
    """

    def __init__(self, config: EfficientNetConfig, in_dim: int, expand_dim: int, expand: bool = False):
        super().__init__()
        self.dim = expand_dim if expand else in_dim
        self.dim_se = max(1, int(in_dim * config.squeeze_expansion_ratio))

        self.squeeze = nn.AdaptiveAvgPool2d(output_size=1)
        self.reduce = nn.Conv2d(
            in_channels=self.dim,
            out_channels=self.dim_se,
            kernel_size=1,
            padding="same",
        )
        self.expand = nn.Conv2d(
            in_channels=self.dim_se,
            out_channels=self.dim,
            kernel_size=1,
            padding="same",
        )
        self.act_reduce = ACT2FN[config.hidden_act]
        self.act_expand = nn.Sigmoid()

    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
        inputs = hidden_states
        hidden_states = self.squeeze(hidden_states)
        hidden_states = self.reduce(hidden_states)
        hidden_states = self.act_reduce(hidden_states)

        hidden_states = self.expand(hidden_states)
        hidden_states = self.act_expand(hidden_states)
        hidden_states = torch.mul(inputs, hidden_states)

        return hidden_states
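

# Shape sketch (illustrative, assuming the default `squeeze_expansion_ratio=0.25`): for a block
# with `in_dim=16` expanded to `expand_dim=96`, a (1, 96, 56, 56) input is pooled to
# (1, 96, 1, 1), reduced to (1, 4, 1, 1) (the bottleneck is sized from the unexpanded
# `in_dim`), expanded back to (1, 96, 1, 1), passed through a sigmoid, and broadcast-multiplied
# with the input to rescale each channel.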


class EfficientNetFinalBlockLayer(nn.Module):
    r"""
    This corresponds to the final phase of each block in the original implementation.
    """

    def __init__(
        self, config: EfficientNetConfig, in_dim: int, out_dim: int, stride: int, drop_rate: float, id_skip: bool
    ):
        super().__init__()
        self.apply_dropout = stride == 1 and not id_skip
        self.project_conv = nn.Conv2d(
            in_channels=in_dim,
            out_channels=out_dim,
            kernel_size=1,
            padding="same",
            bias=False,
        )
        self.project_bn = nn.BatchNorm2d(
            num_features=out_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum
        )
        self.dropout = nn.Dropout(p=drop_rate)

    def forward(self, embeddings: torch.FloatTensor, hidden_states: torch.FloatTensor) -> torch.Tensor:
        hidden_states = self.project_conv(hidden_states)
        hidden_states = self.project_bn(hidden_states)

        if self.apply_dropout:
            hidden_states = self.dropout(hidden_states)
            hidden_states = hidden_states + embeddings

        return hidden_states
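

# The dropout and residual connection are applied together, and only when the block preserves
# its input shape (stride 1 and `id_skip=False`, i.e. not the first block of a stage);
# otherwise the projected hidden states are returned as-is.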


class EfficientNetBlock(nn.Module):
    r"""
    This corresponds to the expansion and depthwise convolution phase of each block in the original implementation.

    Args:
        config ([`EfficientNetConfig`]):
            Model configuration class.
        in_dim (`int`):
            Number of input channels.
        out_dim (`int`):
            Number of output channels.
        stride (`int`):
            Stride size to be used in convolution layers.
        expand_ratio (`int`):
            Expand ratio to set the output dimensions for the expansion and squeeze-excite layers.
        kernel_size (`int`):
            Kernel size for the depthwise convolution layer.
        drop_rate (`float`):
            Dropout rate to be used in the final phase of each block.
        id_skip (`bool`):
            Whether to skip the dropout and residual connection in the final phase of the block. Set to `True` for
            the first block of each stage, whose output shape differs from its input shape.
        adjust_padding (`bool`):
            Whether to apply padding to only the right and bottom sides of the input before the depthwise convolution
            operation. Set to `True` for inputs with odd spatial sizes.
    """

    def __init__(
        self,
        config: EfficientNetConfig,
        in_dim: int,
        out_dim: int,
        stride: int,
        expand_ratio: int,
        kernel_size: int,
        drop_rate: float,
        id_skip: bool,
        adjust_padding: bool,
    ):
        super().__init__()
        self.expand_ratio = expand_ratio
        self.expand = self.expand_ratio != 1
        expand_in_dim = in_dim * expand_ratio

        if self.expand:
            self.expansion = EfficientNetExpansionLayer(
                config=config, in_dim=in_dim, out_dim=expand_in_dim, stride=stride
            )

        self.depthwise_conv = EfficientNetDepthwiseLayer(
            config=config,
            in_dim=expand_in_dim if self.expand else in_dim,
            stride=stride,
            kernel_size=kernel_size,
            adjust_padding=adjust_padding,
        )
        self.squeeze_excite = EfficientNetSqueezeExciteLayer(
            config=config, in_dim=in_dim, expand_dim=expand_in_dim, expand=self.expand
        )
        self.projection = EfficientNetFinalBlockLayer(
            config=config,
            in_dim=expand_in_dim if self.expand else in_dim,
            out_dim=out_dim,
            stride=stride,
            drop_rate=drop_rate,
            id_skip=id_skip,
        )

    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
        embeddings = hidden_states
        # Expansion and depthwise convolution phase
        if self.expand_ratio != 1:
            hidden_states = self.expansion(hidden_states)

        hidden_states = self.depthwise_conv(hidden_states)

        # Squeeze and excite phase
        hidden_states = self.squeeze_excite(hidden_states)
        hidden_states = self.projection(embeddings, hidden_states)

        return hidden_states
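

# Minimal sanity-check sketch (hypothetical values, not a documented example): with the default
# config, a non-expanding block with a matching residual connection keeps its input shape.
#     >>> config = EfficientNetConfig()
#     >>> block = EfficientNetBlock(config, in_dim=32, out_dim=32, stride=1, expand_ratio=1,
#     ...                           kernel_size=3, drop_rate=0.0, id_skip=False, adjust_padding=True)
#     >>> block(torch.randn(1, 32, 112, 112)).shape
#     torch.Size([1, 32, 112, 112])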


class EfficientNetEncoder(nn.Module):
    r"""
    Forward propagates the embeddings through each EfficientNet block.

    Args:
        config ([`EfficientNetConfig`]):
            Model configuration class.
    """

    def __init__(self, config: EfficientNetConfig):
        super().__init__()
        self.config = config
        self.depth_coefficient = config.depth_coefficient

        def round_repeats(repeats):
            # Round the number of block repeats based on the depth multiplier.
            return int(math.ceil(self.depth_coefficient * repeats))
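
        # Illustrative scaling example (hypothetical values): with `depth_coefficient=3.1`,
        # a stage specified with 4 repeats grows to round_repeats(4) = ceil(12.4) = 13 blocks.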

        num_base_blocks = len(config.in_channels)
        num_blocks = sum(round_repeats(n) for n in config.num_block_repeats)

        curr_block_num = 0
        blocks = []
        for i in range(num_base_blocks):
            in_dim = round_filters(config, config.in_channels[i])
            out_dim = round_filters(config, config.out_channels[i])
            stride = config.strides[i]
            kernel_size = config.kernel_sizes[i]
            expand_ratio = config.expand_ratios[i]

            for j in range(round_repeats(config.num_block_repeats[i])):
                id_skip = j == 0
                stride = 1 if j > 0 else stride
                in_dim = out_dim if j > 0 else in_dim
                adjust_padding = curr_block_num not in config.depthwise_padding
                # The drop rate grows linearly with the block index, as in stochastic depth.
                drop_rate = config.drop_connect_rate * curr_block_num / num_blocks

                block = EfficientNetBlock(
                    config=config,
                    in_dim=in_dim,
                    out_dim=out_dim,
                    stride=stride,
                    kernel_size=kernel_size,
                    expand_ratio=expand_ratio,
                    drop_rate=drop_rate,
                    id_skip=id_skip,
                    adjust_padding=adjust_padding,
                )
                blocks.append(block)
                curr_block_num += 1

        self.blocks = nn.ModuleList(blocks)
        self.top_conv = nn.Conv2d(
            in_channels=out_dim,
            out_channels=round_filters(config, 1280),
            kernel_size=1,
            padding="same",
            bias=False,
        )
        self.top_bn = nn.BatchNorm2d(
            num_features=config.hidden_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum
        )
        self.top_activation = ACT2FN[config.hidden_act]

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> BaseModelOutputWithNoAttention:
        all_hidden_states = (hidden_states,) if output_hidden_states else None

        for block in self.blocks:
            hidden_states = block(hidden_states)
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

        hidden_states = self.top_conv(hidden_states)
        hidden_states = self.top_bn(hidden_states)
        hidden_states = self.top_activation(hidden_states)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)

        return BaseModelOutputWithNoAttention(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
        )


class EfficientNetPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = EfficientNetConfig
    base_model_prefix = "efficientnet"
    main_input_name = "pixel_values"
    _no_split_modules = []

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Slightly different from the TF version, which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


@add_start_docstrings(
    "The bare EfficientNet model outputting raw features without any specific head on top.",
    EFFICIENTNET_START_DOCSTRING,
)
class EfficientNetModel(EfficientNetPreTrainedModel):
    def __init__(self, config: EfficientNetConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = EfficientNetEmbeddings(config)
        self.encoder = EfficientNetEncoder(config)

        # Final pooling layer
        if config.pooling_type == "mean":
            self.pooler = nn.AvgPool2d(config.hidden_dim, ceil_mode=True)
        elif config.pooling_type == "max":
            self.pooler = nn.MaxPool2d(config.hidden_dim, ceil_mode=True)
        else:
            raise ValueError(f"config.pooling_type must be one of ['mean', 'max'] got {config.pooling_type}")

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(EFFICIENTNET_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPoolingAndNoAttention,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: torch.FloatTensor = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPoolingAndNoAttention]:
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedding_output = self.embeddings(pixel_values)
        encoder_outputs = self.encoder(
            embedding_output,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Apply pooling
        last_hidden_state = encoder_outputs[0]
        pooled_output = self.pooler(last_hidden_state)
        # Reshape (batch_size, 1280, 1, 1) -> (batch_size, 1280)
        pooled_output = pooled_output.reshape(pooled_output.shape[:2])

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndNoAttention(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
        )
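

# Usage sketch (mirrors the standard transformers vision-model pattern; the checkpoint name
# comes from `_CHECKPOINT_FOR_DOC` above, and `image` is assumed to be a PIL image):
#     >>> from transformers import AutoImageProcessor, EfficientNetModel
#     >>> processor = AutoImageProcessor.from_pretrained("google/efficientnet-b7")
#     >>> model = EfficientNetModel.from_pretrained("google/efficientnet-b7")
#     >>> inputs = processor(images=image, return_tensors="pt")
#     >>> outputs = model(**inputs)
#     >>> outputs.last_hidden_state.shape  # `_EXPECTED_OUTPUT_SHAPE` above documents [1, 768, 7, 7]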


@add_start_docstrings(
    """
    EfficientNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g.
    for ImageNet.
    """,
    EFFICIENTNET_START_DOCSTRING,
)
class EfficientNetForImageClassification(EfficientNetPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.efficientnet = EfficientNetModel(config)
        # Classifier head
        self.dropout = nn.Dropout(p=config.dropout_rate)
        self.classifier = nn.Linear(config.hidden_dim, self.num_labels) if self.num_labels > 0 else nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(EFFICIENTNET_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutputWithNoAttention,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: torch.FloatTensor = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.efficientnet(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)

        pooled_output = outputs.pooler_output if return_dict else outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutputWithNoAttention(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
        )
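

# Usage sketch (mirrors the standard transformers image-classification pattern; the checkpoint
# and expected label come from `_IMAGE_CLASS_CHECKPOINT` / `_IMAGE_CLASS_EXPECTED_OUTPUT` above,
# and `image` is assumed to be a PIL image):
#     >>> from transformers import AutoImageProcessor, EfficientNetForImageClassification
#     >>> processor = AutoImageProcessor.from_pretrained("google/efficientnet-b7")
#     >>> model = EfficientNetForImageClassification.from_pretrained("google/efficientnet-b7")
#     >>> inputs = processor(images=image, return_tensors="pt")
#     >>> logits = model(**inputs).logits
#     >>> model.config.id2label[logits.argmax(-1).item()]
#     'tabby, tabby cat'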