modeling_tf_mobilevit.py 53 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370
  1. # coding=utf-8
  2. # Copyright 2022 Apple Inc. and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. # Original license: https://github.com/apple/ml-cvnets/blob/main/LICENSE
  17. """TensorFlow 2.0 MobileViT model."""
  18. from __future__ import annotations
  19. from typing import Dict, Optional, Tuple, Union
  20. import tensorflow as tf
  21. from ...activations_tf import get_tf_activation
  22. from ...file_utils import (
  23. add_code_sample_docstrings,
  24. add_start_docstrings,
  25. add_start_docstrings_to_model_forward,
  26. replace_return_docstrings,
  27. )
  28. from ...modeling_tf_outputs import (
  29. TFBaseModelOutput,
  30. TFBaseModelOutputWithPooling,
  31. TFImageClassifierOutputWithNoAttention,
  32. TFSemanticSegmenterOutputWithNoAttention,
  33. )
  34. from ...modeling_tf_utils import (
  35. TFPreTrainedModel,
  36. TFSequenceClassificationLoss,
  37. keras,
  38. keras_serializable,
  39. unpack_inputs,
  40. )
  41. from ...tf_utils import shape_list, stable_softmax
  42. from ...utils import logging
  43. from .configuration_mobilevit import MobileViTConfig
  44. logger = logging.get_logger(__name__)
  45. # General docstring
  46. _CONFIG_FOR_DOC = "MobileViTConfig"
  47. # Base docstring
  48. _CHECKPOINT_FOR_DOC = "apple/mobilevit-small"
  49. _EXPECTED_OUTPUT_SHAPE = [1, 640, 8, 8]
  50. # Image classification docstring
  51. _IMAGE_CLASS_CHECKPOINT = "apple/mobilevit-small"
  52. _IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
  53. def make_divisible(value: int, divisor: int = 8, min_value: Optional[int] = None) -> int:
  54. """
  55. Ensure that all layers have a channel count that is divisible by `divisor`. This function is taken from the
  56. original TensorFlow repo. It can be seen here:
  57. https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
  58. """
  59. if min_value is None:
  60. min_value = divisor
  61. new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
  62. # Make sure that round down does not go down by more than 10%.
  63. if new_value < 0.9 * value:
  64. new_value += divisor
  65. return int(new_value)
class TFMobileViTConvLayer(keras.layers.Layer):
    """Convolution block: explicit zero padding -> Conv2D (VALID) -> optional BatchNorm -> optional activation.

    Padding is applied manually via `ZeroPadding2D` and the convolution runs with
    `padding="VALID"`, so output spatial sizes match a symmetrically padded
    convolution (`(kernel_size - 1) // 2 * dilation` on each side).
    """

    def __init__(
        self,
        config: MobileViTConfig,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        groups: int = 1,
        bias: bool = False,
        dilation: int = 1,
        use_normalization: bool = True,
        use_activation: Union[bool, str] = True,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        # NOTE(review): warning refers to grouped-conv backprop support; presumably CPU-only
        # TF kernels are missing for it — confirm against the TF release in use.
        logger.warning(
            f"\n{self.__class__.__name__} has backpropagation operations that are NOT supported on CPU. If you wish "
            "to train/fine-tune this model, you need a GPU or a TPU"
        )
        # Symmetric padding amount per side, scaled by the dilation rate.
        padding = int((kernel_size - 1) / 2) * dilation
        self.padding = keras.layers.ZeroPadding2D(padding)
        if out_channels % groups != 0:
            raise ValueError(f"Output channels ({out_channels}) are not divisible by {groups} groups.")
        self.convolution = keras.layers.Conv2D(
            filters=out_channels,
            kernel_size=kernel_size,
            strides=stride,
            padding="VALID",
            dilation_rate=dilation,
            groups=groups,
            use_bias=bias,
            name="convolution",
        )
        if use_normalization:
            self.normalization = keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1, name="normalization")
        else:
            self.normalization = None
        if use_activation:
            # A string `use_activation` names the activation directly; otherwise fall
            # back to the model-wide `config.hidden_act` (itself a name or a callable).
            if isinstance(use_activation, str):
                self.activation = get_tf_activation(use_activation)
            elif isinstance(config.hidden_act, str):
                self.activation = get_tf_activation(config.hidden_act)
            else:
                self.activation = config.hidden_act
        else:
            self.activation = None
        # Channel counts kept for deferred weight creation in `build()`.
        self.in_channels = in_channels
        self.out_channels = out_channels

    def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Apply pad -> conv -> (norm) -> (activation) to NHWC `features`."""
        padded_features = self.padding(features)
        features = self.convolution(padded_features)
        if self.normalization is not None:
            features = self.normalization(features, training=training)
        if self.activation is not None:
            features = self.activation(features)
        return features

    def build(self, input_shape=None):
        """Create sublayer weights under name scopes matching the serialized layer names."""
        if self.built:
            return
        self.built = True
        if getattr(self, "convolution", None) is not None:
            with tf.name_scope(self.convolution.name):
                self.convolution.build([None, None, None, self.in_channels])
        if getattr(self, "normalization", None) is not None:
            # `normalization` can be None (use_normalization=False); the hasattr check
            # skips building in that case.
            if hasattr(self.normalization, "name"):
                with tf.name_scope(self.normalization.name):
                    self.normalization.build([None, None, None, self.out_channels])
class TFMobileViTInvertedResidual(keras.layers.Layer):
    """
    Inverted residual block (MobileNetv2): https://arxiv.org/abs/1801.04381
    """

    def __init__(
        self, config: MobileViTConfig, in_channels: int, out_channels: int, stride: int, dilation: int = 1, **kwargs
    ) -> None:
        super().__init__(**kwargs)
        # Expansion width, rounded to a multiple of 8 as in the original MobileNet code.
        expanded_channels = make_divisible(int(round(in_channels * config.expand_ratio)), 8)
        if stride not in [1, 2]:
            raise ValueError(f"Invalid stride {stride}.")
        # The identity shortcut is only valid when spatial size and channel count are preserved.
        self.use_residual = (stride == 1) and (in_channels == out_channels)
        # 1x1 pointwise expansion.
        self.expand_1x1 = TFMobileViTConvLayer(
            config, in_channels=in_channels, out_channels=expanded_channels, kernel_size=1, name="expand_1x1"
        )
        # 3x3 depthwise convolution (groups == channels); may downsample via `stride`.
        self.conv_3x3 = TFMobileViTConvLayer(
            config,
            in_channels=expanded_channels,
            out_channels=expanded_channels,
            kernel_size=3,
            stride=stride,
            groups=expanded_channels,
            dilation=dilation,
            name="conv_3x3",
        )
        # 1x1 linear projection back to `out_channels` (no activation).
        self.reduce_1x1 = TFMobileViTConvLayer(
            config,
            in_channels=expanded_channels,
            out_channels=out_channels,
            kernel_size=1,
            use_activation=False,
            name="reduce_1x1",
        )

    def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Expand -> depthwise -> project, adding the identity shortcut when `use_residual`."""
        residual = features
        features = self.expand_1x1(features, training=training)
        features = self.conv_3x3(features, training=training)
        features = self.reduce_1x1(features, training=training)
        return residual + features if self.use_residual else features

    def build(self, input_shape=None):
        """Build each sub-convolution under its own name scope for checkpoint compatibility."""
        if self.built:
            return
        self.built = True
        if getattr(self, "expand_1x1", None) is not None:
            with tf.name_scope(self.expand_1x1.name):
                self.expand_1x1.build(None)
        if getattr(self, "conv_3x3", None) is not None:
            with tf.name_scope(self.conv_3x3.name):
                self.conv_3x3.build(None)
        if getattr(self, "reduce_1x1", None) is not None:
            with tf.name_scope(self.reduce_1x1.name):
                self.reduce_1x1.build(None)
  186. class TFMobileViTMobileNetLayer(keras.layers.Layer):
  187. def __init__(
  188. self,
  189. config: MobileViTConfig,
  190. in_channels: int,
  191. out_channels: int,
  192. stride: int = 1,
  193. num_stages: int = 1,
  194. **kwargs,
  195. ) -> None:
  196. super().__init__(**kwargs)
  197. self.layers = []
  198. for i in range(num_stages):
  199. layer = TFMobileViTInvertedResidual(
  200. config,
  201. in_channels=in_channels,
  202. out_channels=out_channels,
  203. stride=stride if i == 0 else 1,
  204. name=f"layer.{i}",
  205. )
  206. self.layers.append(layer)
  207. in_channels = out_channels
  208. def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor:
  209. for layer_module in self.layers:
  210. features = layer_module(features, training=training)
  211. return features
  212. def build(self, input_shape=None):
  213. if self.built:
  214. return
  215. self.built = True
  216. if getattr(self, "layers", None) is not None:
  217. for layer_module in self.layers:
  218. with tf.name_scope(layer_module.name):
  219. layer_module.build(None)
  220. class TFMobileViTSelfAttention(keras.layers.Layer):
  221. def __init__(self, config: MobileViTConfig, hidden_size: int, **kwargs) -> None:
  222. super().__init__(**kwargs)
  223. if hidden_size % config.num_attention_heads != 0:
  224. raise ValueError(
  225. f"The hidden size {hidden_size,} is not a multiple of the number of attention "
  226. f"heads {config.num_attention_heads}."
  227. )
  228. self.num_attention_heads = config.num_attention_heads
  229. self.attention_head_size = int(hidden_size / config.num_attention_heads)
  230. self.all_head_size = self.num_attention_heads * self.attention_head_size
  231. scale = tf.cast(self.attention_head_size, dtype=tf.float32)
  232. self.scale = tf.math.sqrt(scale)
  233. self.query = keras.layers.Dense(self.all_head_size, use_bias=config.qkv_bias, name="query")
  234. self.key = keras.layers.Dense(self.all_head_size, use_bias=config.qkv_bias, name="key")
  235. self.value = keras.layers.Dense(self.all_head_size, use_bias=config.qkv_bias, name="value")
  236. self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob)
  237. self.hidden_size = hidden_size
  238. def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor:
  239. batch_size = tf.shape(x)[0]
  240. x = tf.reshape(x, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
  241. return tf.transpose(x, perm=[0, 2, 1, 3])
  242. def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
  243. batch_size = tf.shape(hidden_states)[0]
  244. key_layer = self.transpose_for_scores(self.key(hidden_states))
  245. value_layer = self.transpose_for_scores(self.value(hidden_states))
  246. query_layer = self.transpose_for_scores(self.query(hidden_states))
  247. # Take the dot product between "query" and "key" to get the raw attention scores.
  248. attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
  249. attention_scores = attention_scores / self.scale
  250. # Normalize the attention scores to probabilities.
  251. attention_probs = stable_softmax(attention_scores, axis=-1)
  252. # This is actually dropping out entire tokens to attend to, which might
  253. # seem a bit unusual, but is taken from the original Transformer paper.
  254. attention_probs = self.dropout(attention_probs, training=training)
  255. context_layer = tf.matmul(attention_probs, value_layer)
  256. context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
  257. context_layer = tf.reshape(context_layer, shape=(batch_size, -1, self.all_head_size))
  258. return context_layer
  259. def build(self, input_shape=None):
  260. if self.built:
  261. return
  262. self.built = True
  263. if getattr(self, "query", None) is not None:
  264. with tf.name_scope(self.query.name):
  265. self.query.build([None, None, self.hidden_size])
  266. if getattr(self, "key", None) is not None:
  267. with tf.name_scope(self.key.name):
  268. self.key.build([None, None, self.hidden_size])
  269. if getattr(self, "value", None) is not None:
  270. with tf.name_scope(self.value.name):
  271. self.value.build([None, None, self.hidden_size])
  272. class TFMobileViTSelfOutput(keras.layers.Layer):
  273. def __init__(self, config: MobileViTConfig, hidden_size: int, **kwargs) -> None:
  274. super().__init__(**kwargs)
  275. self.dense = keras.layers.Dense(hidden_size, name="dense")
  276. self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
  277. self.hidden_size = hidden_size
  278. def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
  279. hidden_states = self.dense(hidden_states)
  280. hidden_states = self.dropout(hidden_states, training=training)
  281. return hidden_states
  282. def build(self, input_shape=None):
  283. if self.built:
  284. return
  285. self.built = True
  286. if getattr(self, "dense", None) is not None:
  287. with tf.name_scope(self.dense.name):
  288. self.dense.build([None, None, self.hidden_size])
  289. class TFMobileViTAttention(keras.layers.Layer):
  290. def __init__(self, config: MobileViTConfig, hidden_size: int, **kwargs) -> None:
  291. super().__init__(**kwargs)
  292. self.attention = TFMobileViTSelfAttention(config, hidden_size, name="attention")
  293. self.dense_output = TFMobileViTSelfOutput(config, hidden_size, name="output")
  294. def prune_heads(self, heads):
  295. raise NotImplementedError
  296. def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
  297. self_outputs = self.attention(hidden_states, training=training)
  298. attention_output = self.dense_output(self_outputs, training=training)
  299. return attention_output
  300. def build(self, input_shape=None):
  301. if self.built:
  302. return
  303. self.built = True
  304. if getattr(self, "attention", None) is not None:
  305. with tf.name_scope(self.attention.name):
  306. self.attention.build(None)
  307. if getattr(self, "dense_output", None) is not None:
  308. with tf.name_scope(self.dense_output.name):
  309. self.dense_output.build(None)
  310. class TFMobileViTIntermediate(keras.layers.Layer):
  311. def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int, **kwargs) -> None:
  312. super().__init__(**kwargs)
  313. self.dense = keras.layers.Dense(intermediate_size, name="dense")
  314. if isinstance(config.hidden_act, str):
  315. self.intermediate_act_fn = get_tf_activation(config.hidden_act)
  316. else:
  317. self.intermediate_act_fn = config.hidden_act
  318. self.hidden_size = hidden_size
  319. def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
  320. hidden_states = self.dense(hidden_states)
  321. hidden_states = self.intermediate_act_fn(hidden_states)
  322. return hidden_states
  323. def build(self, input_shape=None):
  324. if self.built:
  325. return
  326. self.built = True
  327. if getattr(self, "dense", None) is not None:
  328. with tf.name_scope(self.dense.name):
  329. self.dense.build([None, None, self.hidden_size])
  330. class TFMobileViTOutput(keras.layers.Layer):
  331. def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int, **kwargs) -> None:
  332. super().__init__(**kwargs)
  333. self.dense = keras.layers.Dense(hidden_size, name="dense")
  334. self.dropout = keras.layers.Dropout(config.hidden_dropout_prob)
  335. self.intermediate_size = intermediate_size
  336. def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
  337. hidden_states = self.dense(hidden_states)
  338. hidden_states = self.dropout(hidden_states, training=training)
  339. hidden_states = hidden_states + input_tensor
  340. return hidden_states
  341. def build(self, input_shape=None):
  342. if self.built:
  343. return
  344. self.built = True
  345. if getattr(self, "dense", None) is not None:
  346. with tf.name_scope(self.dense.name):
  347. self.dense.build([None, None, self.intermediate_size])
class TFMobileViTTransformerLayer(keras.layers.Layer):
    """A single pre-norm transformer encoder layer: LN -> attention -> residual, LN -> MLP -> residual."""

    def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int, **kwargs) -> None:
        super().__init__(**kwargs)
        self.attention = TFMobileViTAttention(config, hidden_size, name="attention")
        self.intermediate = TFMobileViTIntermediate(config, hidden_size, intermediate_size, name="intermediate")
        # Attribute is `mobilevit_output` (presumably to avoid clashing with Keras'
        # own `output` attribute) but the serialized weight name stays "output".
        self.mobilevit_output = TFMobileViTOutput(config, hidden_size, intermediate_size, name="output")
        self.layernorm_before = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_before")
        self.layernorm_after = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_after")
        self.hidden_size = hidden_size

    def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Apply pre-norm attention and feed-forward sublayers, each with a residual connection."""
        attention_output = self.attention(self.layernorm_before(hidden_states), training=training)
        hidden_states = attention_output + hidden_states
        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)
        # `mobilevit_output` performs the second residual add internally.
        layer_output = self.mobilevit_output(layer_output, hidden_states, training=training)
        return layer_output

    def build(self, input_shape=None):
        """Build all sublayers under their serialized name scopes."""
        if self.built:
            return
        self.built = True
        if getattr(self, "attention", None) is not None:
            with tf.name_scope(self.attention.name):
                self.attention.build(None)
        if getattr(self, "intermediate", None) is not None:
            with tf.name_scope(self.intermediate.name):
                self.intermediate.build(None)
        if getattr(self, "mobilevit_output", None) is not None:
            with tf.name_scope(self.mobilevit_output.name):
                self.mobilevit_output.build(None)
        if getattr(self, "layernorm_before", None) is not None:
            with tf.name_scope(self.layernorm_before.name):
                self.layernorm_before.build([None, None, self.hidden_size])
        if getattr(self, "layernorm_after", None) is not None:
            with tf.name_scope(self.layernorm_after.name):
                self.layernorm_after.build([None, None, self.hidden_size])
  383. class TFMobileViTTransformer(keras.layers.Layer):
  384. def __init__(self, config: MobileViTConfig, hidden_size: int, num_stages: int, **kwargs) -> None:
  385. super().__init__(**kwargs)
  386. self.layers = []
  387. for i in range(num_stages):
  388. transformer_layer = TFMobileViTTransformerLayer(
  389. config,
  390. hidden_size=hidden_size,
  391. intermediate_size=int(hidden_size * config.mlp_ratio),
  392. name=f"layer.{i}",
  393. )
  394. self.layers.append(transformer_layer)
  395. def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
  396. for layer_module in self.layers:
  397. hidden_states = layer_module(hidden_states, training=training)
  398. return hidden_states
  399. def build(self, input_shape=None):
  400. if self.built:
  401. return
  402. self.built = True
  403. if getattr(self, "layers", None) is not None:
  404. for layer_module in self.layers:
  405. with tf.name_scope(layer_module.name):
  406. layer_module.build(None)
class TFMobileViTLayer(keras.layers.Layer):
    """
    MobileViT block: https://arxiv.org/abs/2110.02178
    """

    def __init__(
        self,
        config: MobileViTConfig,
        in_channels: int,
        out_channels: int,
        stride: int,
        hidden_size: int,
        num_stages: int,
        dilation: int = 1,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.patch_width = config.patch_size
        self.patch_height = config.patch_size
        if stride == 2:
            # Optional strided inverted-residual block that halves the spatial size
            # (or dilates instead of striding when dilation > 1).
            self.downsampling_layer = TFMobileViTInvertedResidual(
                config,
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride if dilation == 1 else 1,
                dilation=dilation // 2 if dilation > 1 else 1,
                name="downsampling_layer",
            )
            in_channels = out_channels
        else:
            self.downsampling_layer = None
        # Local representation: kxk conv followed by a linear 1x1 projection to `hidden_size`.
        self.conv_kxk = TFMobileViTConvLayer(
            config,
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=config.conv_kernel_size,
            name="conv_kxk",
        )
        self.conv_1x1 = TFMobileViTConvLayer(
            config,
            in_channels=in_channels,
            out_channels=hidden_size,
            kernel_size=1,
            use_normalization=False,
            use_activation=False,
            name="conv_1x1",
        )
        # Global representation over unfolded patches.
        self.transformer = TFMobileViTTransformer(
            config, hidden_size=hidden_size, num_stages=num_stages, name="transformer"
        )
        self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
        # Project transformer output back to the convolutional channel count.
        self.conv_projection = TFMobileViTConvLayer(
            config, in_channels=hidden_size, out_channels=in_channels, kernel_size=1, name="conv_projection"
        )
        # Fuse the local (residual) and global branches, concatenated on channels.
        self.fusion = TFMobileViTConvLayer(
            config,
            in_channels=2 * in_channels,
            out_channels=in_channels,
            kernel_size=config.conv_kernel_size,
            name="fusion",
        )
        self.hidden_size = hidden_size

    def unfolding(self, features: tf.Tensor) -> Tuple[tf.Tensor, Dict]:
        """Rearrange an NHWC feature map into per-patch token sequences.

        Returns `(patches, info_dict)` where `patches` has shape
        (batch_size * patch_area, num_patches, channels) and `info_dict` carries
        everything `folding()` needs to invert the transform.
        """
        patch_width, patch_height = self.patch_width, self.patch_height
        patch_area = tf.cast(patch_width * patch_height, "int32")
        batch_size = tf.shape(features)[0]
        orig_height = tf.shape(features)[1]
        orig_width = tf.shape(features)[2]
        channels = tf.shape(features)[3]
        # Round the spatial size up to a multiple of the patch size.
        new_height = tf.cast(tf.math.ceil(orig_height / patch_height) * patch_height, "int32")
        new_width = tf.cast(tf.math.ceil(orig_width / patch_width) * patch_width, "int32")
        interpolate = new_width != orig_width or new_height != orig_height
        if interpolate:
            # Note: Padding can be done, but then it needs to be handled in attention function.
            features = tf.image.resize(features, size=(new_height, new_width), method="bilinear")
        # number of patches along width and height
        num_patch_width = new_width // patch_width
        num_patch_height = new_height // patch_height
        num_patches = num_patch_height * num_patch_width
        # convert from shape (batch_size, orig_height, orig_width, channels)
        # to the shape (batch_size * patch_area, num_patches, channels)
        features = tf.transpose(features, [0, 3, 1, 2])
        patches = tf.reshape(
            features, (batch_size * channels * num_patch_height, patch_height, num_patch_width, patch_width)
        )
        patches = tf.transpose(patches, [0, 2, 1, 3])
        patches = tf.reshape(patches, (batch_size, channels, num_patches, patch_area))
        patches = tf.transpose(patches, [0, 3, 2, 1])
        patches = tf.reshape(patches, (batch_size * patch_area, num_patches, channels))
        info_dict = {
            "orig_size": (orig_height, orig_width),
            "batch_size": batch_size,
            "channels": channels,
            "interpolate": interpolate,
            "num_patches": num_patches,
            "num_patches_width": num_patch_width,
            "num_patches_height": num_patch_height,
        }
        return patches, info_dict

    def folding(self, patches: tf.Tensor, info_dict: Dict) -> tf.Tensor:
        """Invert `unfolding()`: reassemble patch tokens into an NHWC feature map.

        Resizes back to `info_dict["orig_size"]` if `unfolding()` interpolated.
        """
        patch_width, patch_height = self.patch_width, self.patch_height
        patch_area = int(patch_width * patch_height)
        batch_size = info_dict["batch_size"]
        channels = info_dict["channels"]
        num_patches = info_dict["num_patches"]
        num_patch_height = info_dict["num_patches_height"]
        num_patch_width = info_dict["num_patches_width"]
        # convert from shape (batch_size * patch_area, num_patches, channels)
        # back to shape (batch_size, channels, orig_height, orig_width)
        features = tf.reshape(patches, (batch_size, patch_area, num_patches, -1))
        features = tf.transpose(features, perm=(0, 3, 2, 1))
        features = tf.reshape(
            features, (batch_size * channels * num_patch_height, num_patch_width, patch_height, patch_width)
        )
        features = tf.transpose(features, perm=(0, 2, 1, 3))
        features = tf.reshape(
            features, (batch_size, channels, num_patch_height * patch_height, num_patch_width * patch_width)
        )
        # Back to channels-last (NHWC).
        features = tf.transpose(features, perm=(0, 2, 3, 1))
        if info_dict["interpolate"]:
            features = tf.image.resize(features, size=info_dict["orig_size"], method="bilinear")
        return features

    def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor:
        """Local conv features + transformer over patches, fused with a residual branch."""
        # reduce spatial dimensions if needed
        if self.downsampling_layer:
            features = self.downsampling_layer(features, training=training)
        residual = features
        # local representation
        features = self.conv_kxk(features, training=training)
        features = self.conv_1x1(features, training=training)
        # convert feature map to patches
        patches, info_dict = self.unfolding(features)
        # learn global representations
        patches = self.transformer(patches, training=training)
        patches = self.layernorm(patches)
        # convert patches back to feature maps
        features = self.folding(patches, info_dict)
        features = self.conv_projection(features, training=training)
        features = self.fusion(tf.concat([residual, features], axis=-1), training=training)
        return features

    def build(self, input_shape=None):
        """Build all sublayers under their serialized name scopes."""
        if self.built:
            return
        self.built = True
        if getattr(self, "conv_kxk", None) is not None:
            with tf.name_scope(self.conv_kxk.name):
                self.conv_kxk.build(None)
        if getattr(self, "conv_1x1", None) is not None:
            with tf.name_scope(self.conv_1x1.name):
                self.conv_1x1.build(None)
        if getattr(self, "transformer", None) is not None:
            with tf.name_scope(self.transformer.name):
                self.transformer.build(None)
        if getattr(self, "layernorm", None) is not None:
            with tf.name_scope(self.layernorm.name):
                self.layernorm.build([None, None, self.hidden_size])
        if getattr(self, "conv_projection", None) is not None:
            with tf.name_scope(self.conv_projection.name):
                self.conv_projection.build(None)
        if getattr(self, "fusion", None) is not None:
            with tf.name_scope(self.fusion.name):
                self.fusion.build(None)
        if getattr(self, "downsampling_layer", None) is not None:
            with tf.name_scope(self.downsampling_layer.name):
                self.downsampling_layer.build(None)
class TFMobileViTEncoder(keras.layers.Layer):
    """Five-stage MobileViT backbone: two MobileNet stages followed by three MobileViT stages."""

    def __init__(self, config: MobileViTConfig, **kwargs) -> None:
        super().__init__(**kwargs)
        self.config = config
        self.layers = []
        # segmentation architectures like DeepLab and PSPNet modify the strides
        # of the classification backbones
        dilate_layer_4 = dilate_layer_5 = False
        if config.output_stride == 8:
            dilate_layer_4 = True
            dilate_layer_5 = True
        elif config.output_stride == 16:
            dilate_layer_5 = True
        dilation = 1
        layer_1 = TFMobileViTMobileNetLayer(
            config,
            in_channels=config.neck_hidden_sizes[0],
            out_channels=config.neck_hidden_sizes[1],
            stride=1,
            num_stages=1,
            name="layer.0",
        )
        self.layers.append(layer_1)
        layer_2 = TFMobileViTMobileNetLayer(
            config,
            in_channels=config.neck_hidden_sizes[1],
            out_channels=config.neck_hidden_sizes[2],
            stride=2,
            num_stages=3,
            name="layer.1",
        )
        self.layers.append(layer_2)
        layer_3 = TFMobileViTLayer(
            config,
            in_channels=config.neck_hidden_sizes[2],
            out_channels=config.neck_hidden_sizes[3],
            stride=2,
            hidden_size=config.hidden_sizes[0],
            num_stages=2,
            name="layer.2",
        )
        self.layers.append(layer_3)
        # When a stage is dilated instead of strided, double the dilation rate before
        # constructing that stage.
        if dilate_layer_4:
            dilation *= 2
        layer_4 = TFMobileViTLayer(
            config,
            in_channels=config.neck_hidden_sizes[3],
            out_channels=config.neck_hidden_sizes[4],
            stride=2,
            hidden_size=config.hidden_sizes[1],
            num_stages=4,
            dilation=dilation,
            name="layer.3",
        )
        self.layers.append(layer_4)
        if dilate_layer_5:
            dilation *= 2
        layer_5 = TFMobileViTLayer(
            config,
            in_channels=config.neck_hidden_sizes[4],
            out_channels=config.neck_hidden_sizes[5],
            stride=2,
            hidden_size=config.hidden_sizes[2],
            num_stages=3,
            dilation=dilation,
            name="layer.4",
        )
        self.layers.append(layer_5)

    def call(
        self,
        hidden_states: tf.Tensor,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        training: bool = False,
    ) -> Union[tuple, TFBaseModelOutput]:
        """Run all stages; optionally collect every stage's output.

        Returns a `TFBaseModelOutput` when `return_dict` is True, otherwise a tuple
        of the non-None values among (last_hidden_state, all_hidden_states).
        """
        all_hidden_states = () if output_hidden_states else None
        for i, layer_module in enumerate(self.layers):
            hidden_states = layer_module(hidden_states, training=training)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
        return TFBaseModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states)

    def build(self, input_shape=None):
        """Build each stage under its own name scope."""
        if self.built:
            return
        self.built = True
        if getattr(self, "layers", None) is not None:
            for layer_module in self.layers:
                with tf.name_scope(layer_module.name):
                    layer_module.build(None)
@keras_serializable
class TFMobileViTMainLayer(keras.layers.Layer):
    """Core MobileViT model: conv stem -> encoder -> optional 1x1 expansion and pooling.

    Inputs are expected in NCHW layout and transposed to NHWC internally; outputs
    are transposed back to NCHW for uniformity with the other framework ports.
    """

    config_class = MobileViTConfig

    def __init__(self, config: MobileViTConfig, expand_output: bool = True, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        # When True (classification-style use), a 1x1 expansion conv and a
        # global-average pooler are appended; segmentation passes False.
        self.expand_output = expand_output

        self.conv_stem = TFMobileViTConvLayer(
            config,
            in_channels=config.num_channels,
            out_channels=config.neck_hidden_sizes[0],
            kernel_size=3,
            stride=2,
            name="conv_stem",
        )
        self.encoder = TFMobileViTEncoder(config, name="encoder")

        if self.expand_output:
            self.conv_1x1_exp = TFMobileViTConvLayer(
                config,
                in_channels=config.neck_hidden_sizes[5],
                out_channels=config.neck_hidden_sizes[6],
                kernel_size=1,
                name="conv_1x1_exp",
            )
            self.pooler = keras.layers.GlobalAveragePooling2D(data_format="channels_first", name="pooler")

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        raise NotImplementedError

    @unpack_inputs
    def call(
        self,
        pixel_values: tf.Tensor | None = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[Tuple[tf.Tensor], TFBaseModelOutputWithPooling]:
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format.
        # So change the input format from `NCHW` to `NHWC`.
        # shape = (batch_size, in_height, in_width, in_channels=num_channels)
        pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))

        embedding_output = self.conv_stem(pixel_values, training=training)

        encoder_outputs = self.encoder(
            embedding_output, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training
        )

        if self.expand_output:
            last_hidden_state = self.conv_1x1_exp(encoder_outputs[0])

            # Change to NCHW output format to have uniformity in the modules
            last_hidden_state = tf.transpose(last_hidden_state, perm=[0, 3, 1, 2])

            # global average pooling: (batch_size, channels, height, width) -> (batch_size, channels)
            pooled_output = self.pooler(last_hidden_state)
        else:
            last_hidden_state = encoder_outputs[0]
            # Change to NCHW output format to have uniformity in the modules
            last_hidden_state = tf.transpose(last_hidden_state, perm=[0, 3, 1, 2])
            pooled_output = None

        if not return_dict:
            output = (last_hidden_state, pooled_output) if pooled_output is not None else (last_hidden_state,)
            # Change to NCHW output format to have uniformity in the modules
            if not self.expand_output:
                # With return_dict=False the encoder returned a tuple whose tail
                # is `(all_hidden_states,)`; transpose each hidden state to NCHW
                # and re-wrap it in a 1-tuple to keep the output layout.
                remaining_encoder_outputs = encoder_outputs[1:]
                remaining_encoder_outputs = tuple(
                    [tf.transpose(h, perm=(0, 3, 1, 2)) for h in remaining_encoder_outputs[0]]
                )
                remaining_encoder_outputs = (remaining_encoder_outputs,)
                return output + remaining_encoder_outputs
            else:
                return output + encoder_outputs[1:]

        # Change the other hidden state outputs to NCHW as well
        if output_hidden_states:
            hidden_states = tuple([tf.transpose(h, perm=(0, 3, 1, 2)) for h in encoder_outputs[1]])

        return TFBaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=hidden_states if output_hidden_states else encoder_outputs.hidden_states,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "conv_stem", None) is not None:
            with tf.name_scope(self.conv_stem.name):
                self.conv_stem.build(None)
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)
        # `pooler` and `conv_1x1_exp` only exist when expand_output=True.
        if getattr(self, "pooler", None) is not None:
            with tf.name_scope(self.pooler.name):
                self.pooler.build([None, None, None, None])
        if getattr(self, "conv_1x1_exp", None) is not None:
            with tf.name_scope(self.conv_1x1_exp.name):
                self.conv_1x1_exp.build(None)
class TFMobileViTPreTrainedModel(TFPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # Configuration class associated with this model family.
    config_class = MobileViTConfig
    # Attribute name under which the base (headless) model is stored.
    base_model_prefix = "mobilevit"
    # Name of the model's primary input tensor.
    main_input_name = "pixel_values"
# Shared docstring prepended to the model classes below via `add_start_docstrings`.
MOBILEVIT_START_DOCSTRING = r"""
    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
    behavior.

    <Tip>

    TensorFlow models and layers in `transformers` accept two formats as input:

    - having all inputs as keyword arguments (like PyTorch models), or
    - having all inputs as a list, tuple or dict in the first positional argument.

    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
    positional argument:

    - a single Tensor with `pixel_values` only and nothing else: `model(pixel_values)`
    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
    `model([pixel_values, attention_mask])` or `model([pixel_values, attention_mask, token_type_ids])`
    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
    `model({"pixel_values": pixel_values, "token_type_ids": token_type_ids})`

    Note that when creating models and layers with
    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
    about any of this, as you can just pass inputs like you would to any other Python function!

    </Tip>

    Parameters:
        config ([`MobileViTConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""
# Shared `call` docstring injected via `add_start_docstrings_to_model_forward`.
MOBILEVIT_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`MobileViTImageProcessor.__call__`] for details.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
            used instead.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
            eager mode, in graph mode the value will always be set to True.
"""
  812. @add_start_docstrings(
  813. "The bare MobileViT model outputting raw hidden-states without any specific head on top.",
  814. MOBILEVIT_START_DOCSTRING,
  815. )
  816. class TFMobileViTModel(TFMobileViTPreTrainedModel):
  817. def __init__(self, config: MobileViTConfig, expand_output: bool = True, *inputs, **kwargs):
  818. super().__init__(config, *inputs, **kwargs)
  819. self.config = config
  820. self.expand_output = expand_output
  821. self.mobilevit = TFMobileViTMainLayer(config, expand_output=expand_output, name="mobilevit")
  822. @unpack_inputs
  823. @add_start_docstrings_to_model_forward(MOBILEVIT_INPUTS_DOCSTRING)
  824. @add_code_sample_docstrings(
  825. checkpoint=_CHECKPOINT_FOR_DOC,
  826. output_type=TFBaseModelOutputWithPooling,
  827. config_class=_CONFIG_FOR_DOC,
  828. modality="vision",
  829. expected_output=_EXPECTED_OUTPUT_SHAPE,
  830. )
  831. def call(
  832. self,
  833. pixel_values: tf.Tensor | None = None,
  834. output_hidden_states: Optional[bool] = None,
  835. return_dict: Optional[bool] = None,
  836. training: bool = False,
  837. ) -> Union[Tuple[tf.Tensor], TFBaseModelOutputWithPooling]:
  838. output = self.mobilevit(pixel_values, output_hidden_states, return_dict, training=training)
  839. return output
  840. def build(self, input_shape=None):
  841. if self.built:
  842. return
  843. self.built = True
  844. if getattr(self, "mobilevit", None) is not None:
  845. with tf.name_scope(self.mobilevit.name):
  846. self.mobilevit.build(None)
@add_start_docstrings(
    """
    MobileViT model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
    ImageNet.
    """,
    MOBILEVIT_START_DOCSTRING,
)
class TFMobileViTForImageClassification(TFMobileViTPreTrainedModel, TFSequenceClassificationLoss):
    def __init__(self, config: MobileViTConfig, *inputs, **kwargs) -> None:
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels
        self.mobilevit = TFMobileViTMainLayer(config, name="mobilevit")

        # Classifier head
        self.dropout = keras.layers.Dropout(config.classifier_dropout_prob)
        # `tf.identity` stands in for the head when no labels are configured,
        # so `call` can apply `self.classifier` unconditionally.
        self.classifier = (
            keras.layers.Dense(config.num_labels, name="classifier") if config.num_labels > 0 else tf.identity
        )
        self.config = config

    @unpack_inputs
    @add_start_docstrings_to_model_forward(MOBILEVIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=TFImageClassifierOutputWithNoAttention,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def call(
        self,
        pixel_values: tf.Tensor | None = None,
        output_hidden_states: Optional[bool] = None,
        labels: tf.Tensor | None = None,
        return_dict: Optional[bool] = None,
        training: Optional[bool] = False,
    ) -> Union[tuple, TFImageClassifierOutputWithNoAttention]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.mobilevit(
            pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict, training=training
        )

        # Tuple layout (return_dict=False): (last_hidden_state, pooled_output, ...).
        pooled_output = outputs.pooler_output if return_dict else outputs[1]

        logits = self.classifier(self.dropout(pooled_output, training=training))
        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)

        if not return_dict:
            # Drop last_hidden_state and pooled_output; keep any hidden states.
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "mobilevit", None) is not None:
            with tf.name_scope(self.mobilevit.name):
                self.mobilevit.build(None)
        if getattr(self, "classifier", None) is not None:
            # `tf.identity` has no `name`/`build`; only build a real Dense head.
            if hasattr(self.classifier, "name"):
                with tf.name_scope(self.classifier.name):
                    self.classifier.build([None, None, self.config.neck_hidden_sizes[-1]])
  909. class TFMobileViTASPPPooling(keras.layers.Layer):
  910. def __init__(self, config: MobileViTConfig, in_channels: int, out_channels: int, **kwargs) -> None:
  911. super().__init__(**kwargs)
  912. self.global_pool = keras.layers.GlobalAveragePooling2D(keepdims=True, name="global_pool")
  913. self.conv_1x1 = TFMobileViTConvLayer(
  914. config,
  915. in_channels=in_channels,
  916. out_channels=out_channels,
  917. kernel_size=1,
  918. stride=1,
  919. use_normalization=True,
  920. use_activation="relu",
  921. name="conv_1x1",
  922. )
  923. def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor:
  924. spatial_size = shape_list(features)[1:-1]
  925. features = self.global_pool(features)
  926. features = self.conv_1x1(features, training=training)
  927. features = tf.image.resize(features, size=spatial_size, method="bilinear")
  928. return features
  929. def build(self, input_shape=None):
  930. if self.built:
  931. return
  932. self.built = True
  933. if getattr(self, "global_pool", None) is not None:
  934. with tf.name_scope(self.global_pool.name):
  935. self.global_pool.build([None, None, None, None])
  936. if getattr(self, "conv_1x1", None) is not None:
  937. with tf.name_scope(self.conv_1x1.name):
  938. self.conv_1x1.build(None)
  939. class TFMobileViTASPP(keras.layers.Layer):
  940. """
  941. ASPP module defined in DeepLab papers: https://arxiv.org/abs/1606.00915, https://arxiv.org/abs/1706.05587
  942. """
  943. def __init__(self, config: MobileViTConfig, **kwargs) -> None:
  944. super().__init__(**kwargs)
  945. in_channels = config.neck_hidden_sizes[-2]
  946. out_channels = config.aspp_out_channels
  947. if len(config.atrous_rates) != 3:
  948. raise ValueError("Expected 3 values for atrous_rates")
  949. self.convs = []
  950. in_projection = TFMobileViTConvLayer(
  951. config,
  952. in_channels=in_channels,
  953. out_channels=out_channels,
  954. kernel_size=1,
  955. use_activation="relu",
  956. name="convs.0",
  957. )
  958. self.convs.append(in_projection)
  959. self.convs.extend(
  960. [
  961. TFMobileViTConvLayer(
  962. config,
  963. in_channels=in_channels,
  964. out_channels=out_channels,
  965. kernel_size=3,
  966. dilation=rate,
  967. use_activation="relu",
  968. name=f"convs.{i + 1}",
  969. )
  970. for i, rate in enumerate(config.atrous_rates)
  971. ]
  972. )
  973. pool_layer = TFMobileViTASPPPooling(
  974. config, in_channels, out_channels, name=f"convs.{len(config.atrous_rates) + 1}"
  975. )
  976. self.convs.append(pool_layer)
  977. self.project = TFMobileViTConvLayer(
  978. config,
  979. in_channels=5 * out_channels,
  980. out_channels=out_channels,
  981. kernel_size=1,
  982. use_activation="relu",
  983. name="project",
  984. )
  985. self.dropout = keras.layers.Dropout(config.aspp_dropout_prob)
  986. def call(self, features: tf.Tensor, training: bool = False) -> tf.Tensor:
  987. # since the hidden states were transposed to have `(batch_size, channels, height, width)`
  988. # layout we transpose them back to have `(batch_size, height, width, channels)` layout.
  989. features = tf.transpose(features, perm=[0, 2, 3, 1])
  990. pyramid = []
  991. for conv in self.convs:
  992. pyramid.append(conv(features, training=training))
  993. pyramid = tf.concat(pyramid, axis=-1)
  994. pooled_features = self.project(pyramid, training=training)
  995. pooled_features = self.dropout(pooled_features, training=training)
  996. return pooled_features
  997. def build(self, input_shape=None):
  998. if self.built:
  999. return
  1000. self.built = True
  1001. if getattr(self, "project", None) is not None:
  1002. with tf.name_scope(self.project.name):
  1003. self.project.build(None)
  1004. if getattr(self, "convs", None) is not None:
  1005. for conv in self.convs:
  1006. with tf.name_scope(conv.name):
  1007. conv.build(None)
  1008. class TFMobileViTDeepLabV3(keras.layers.Layer):
  1009. """
  1010. DeepLabv3 architecture: https://arxiv.org/abs/1706.05587
  1011. """
  1012. def __init__(self, config: MobileViTConfig, **kwargs) -> None:
  1013. super().__init__(**kwargs)
  1014. self.aspp = TFMobileViTASPP(config, name="aspp")
  1015. self.dropout = keras.layers.Dropout(config.classifier_dropout_prob)
  1016. self.classifier = TFMobileViTConvLayer(
  1017. config,
  1018. in_channels=config.aspp_out_channels,
  1019. out_channels=config.num_labels,
  1020. kernel_size=1,
  1021. use_normalization=False,
  1022. use_activation=False,
  1023. bias=True,
  1024. name="classifier",
  1025. )
  1026. def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
  1027. features = self.aspp(hidden_states[-1], training=training)
  1028. features = self.dropout(features, training=training)
  1029. features = self.classifier(features, training=training)
  1030. return features
  1031. def build(self, input_shape=None):
  1032. if self.built:
  1033. return
  1034. self.built = True
  1035. if getattr(self, "aspp", None) is not None:
  1036. with tf.name_scope(self.aspp.name):
  1037. self.aspp.build(None)
  1038. if getattr(self, "classifier", None) is not None:
  1039. with tf.name_scope(self.classifier.name):
  1040. self.classifier.build(None)
@add_start_docstrings(
    """
    MobileViT model with a semantic segmentation head on top, e.g. for Pascal VOC.
    """,
    MOBILEVIT_START_DOCSTRING,
)
class TFMobileViTForSemanticSegmentation(TFMobileViTPreTrainedModel):
    def __init__(self, config: MobileViTConfig, **kwargs) -> None:
        super().__init__(config, **kwargs)
        self.num_labels = config.num_labels
        # expand_output=False: the DeepLabV3 head consumes intermediate hidden
        # states rather than the pooled classification features.
        self.mobilevit = TFMobileViTMainLayer(config, expand_output=False, name="mobilevit")
        self.segmentation_head = TFMobileViTDeepLabV3(config, name="segmentation_head")

    def hf_compute_loss(self, logits, labels):
        """Mean cross-entropy over upsampled logits, excluding pixels whose label
        equals `config.semantic_loss_ignore_index`."""
        # upsample logits to the images' original size
        # `labels` is of shape (batch_size, height, width)
        label_interp_shape = shape_list(labels)[1:]

        upsampled_logits = tf.image.resize(logits, size=label_interp_shape, method="bilinear")
        # compute weighted loss
        loss_fct = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")

        def masked_loss(real, pred):
            unmasked_loss = loss_fct(real, pred)
            # Zero out ignored pixels, then average over the kept ones only.
            mask = tf.cast(real != self.config.semantic_loss_ignore_index, dtype=unmasked_loss.dtype)
            masked_loss = unmasked_loss * mask
            # Reduction strategy in the similar spirit with
            # https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_tf_utils.py#L210
            reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(mask)
            return tf.reshape(reduced_masked_loss, (1,))

        return masked_loss(labels, upsampled_logits)

    @unpack_inputs
    @add_start_docstrings_to_model_forward(MOBILEVIT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFSemanticSegmenterOutputWithNoAttention, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        pixel_values: tf.Tensor | None = None,
        labels: tf.Tensor | None = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        training: bool = False,
    ) -> Union[tuple, TFSemanticSegmenterOutputWithNoAttention]:
        r"""
        labels (`tf.Tensor` of shape `(batch_size, height, width)`, *optional*):
            Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, TFMobileViTForSemanticSegmentation
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("apple/deeplabv3-mobilevit-small")
        >>> model = TFMobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-small")

        >>> inputs = image_processor(images=image, return_tensors="tf")

        >>> outputs = model(**inputs)

        >>> # logits are of shape (batch_size, num_labels, height, width)
        >>> logits = outputs.logits
        ```"""
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None and not self.config.num_labels > 1:
            raise ValueError("The number of labels should be greater than one")

        outputs = self.mobilevit(
            pixel_values,
            output_hidden_states=True,  # we need the intermediate hidden states
            return_dict=return_dict,
            training=training,
        )

        encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]

        logits = self.segmentation_head(encoder_hidden_states, training=training)

        loss = None
        if labels is not None:
            loss = self.hf_compute_loss(logits=logits, labels=labels)

        # make logits of shape (batch_size, num_labels, height, width) to
        # keep them consistent across APIs
        logits = tf.transpose(logits, perm=[0, 3, 1, 2])

        if not return_dict:
            if output_hidden_states:
                # Keep the hidden states right after the logits.
                output = (logits,) + outputs[1:]
            else:
                output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TFSemanticSegmenterOutputWithNoAttention(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "mobilevit", None) is not None:
            with tf.name_scope(self.mobilevit.name):
                self.mobilevit.build(None)
        if getattr(self, "segmentation_head", None) is not None:
            with tf.name_scope(self.segmentation_head.name):
                self.segmentation_head.build(None)