backbone_utils.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377
  1. # coding=utf-8
  2. # Copyright 2023 The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Collection of utils to be used by backbones and their components."""
  16. import enum
  17. import inspect
  18. from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
  19. if TYPE_CHECKING:
  20. from ..configuration_utils import PretrainedConfig
  class BackboneType(enum.Enum):
      """Which implementation family a backbone comes from."""
  
      # Backbone built on top of a `timm` model.
      TIMM = "timm"
      # Backbone implemented as a transformers model.
      TRANSFORMERS = "transformers"
  def verify_out_features_out_indices(
      out_features: Optional[Iterable[str]], out_indices: Optional[Iterable[int]], stage_names: Optional[Iterable[str]]
  ):
      """
      Verify that out_indices and out_features are valid for the given stage_names.
  
      Args:
          out_features (`List[str]`, *optional*):
              Names of the stages to output. Must be a list, a duplicate-free subset of `stage_names`, and ordered
              like `stage_names`.
          out_indices (`List[int]`, *optional*):
              Indices of the stages to output. Must be a list of indices in `[-len(stage_names), len(stage_names))`
              (negative values count from the end), with no two entries referring to the same stage, ordered like
              `stage_names`.
          stage_names (`List[str]`):
              Names of all the stages of the backbone.
  
      Raises:
          ValueError: if `stage_names` is None, if any constraint above is violated, or if both `out_features` and
              `out_indices` are given but do not refer to the same stages.
      """
      if stage_names is None:
          raise ValueError("Stage_names must be set for transformers backbones")
  
      if out_features is not None:
          if not isinstance(out_features, (list,)):
              raise ValueError(f"out_features must be a list got {type(out_features)}")
          if any(feat not in stage_names for feat in out_features):
              raise ValueError(f"out_features must be a subset of stage_names: {stage_names} got {out_features}")
          if len(out_features) != len(set(out_features)):
              raise ValueError(f"out_features must not contain any duplicates, got {out_features}")
          if out_features != (sorted_feats := [feat for feat in stage_names if feat in out_features]):
              raise ValueError(
                  f"out_features must be in the same order as stage_names, expected {sorted_feats} got {out_features}"
              )
  
      if out_indices is not None:
          if not isinstance(out_indices, list):
              raise ValueError(f"out_indices must be a list, got {type(out_indices)}")
          num_stages = len(stage_names)
          # Validate the range BEFORE normalising negative indices: `idx % num_stages` would otherwise silently wrap
          # an out-of-range negative index (e.g. -5 with 4 stages) onto a valid stage, and would divide by zero when
          # there are no stages at all.
          if any(not -num_stages <= idx < num_stages for idx in out_indices):
              raise ValueError(f"out_indices must be valid indices for stage_names {stage_names}, got {out_indices}")
          # Convert negative indices to their positive equivalent: [-1,] -> [len(stage_names) - 1,]
          positive_indices = tuple(idx % num_stages if idx < 0 else idx for idx in out_indices)
          if len(positive_indices) != len(set(positive_indices)):
              msg = f"out_indices must not contain any duplicates, got {out_indices}"
              # Only show the normalised form when negative indices made it differ from what the caller passed.
              if positive_indices != tuple(out_indices):
                  msg += f" (equivalent to {list(positive_indices)})"
              raise ValueError(msg)
          if positive_indices != tuple(sorted(positive_indices)):
              sorted_negative = [idx for _, idx in sorted(zip(positive_indices, out_indices), key=lambda x: x[0])]
              raise ValueError(
                  f"out_indices must be in the same order as stage_names, expected {sorted_negative} got {out_indices}"
              )
  
      if out_features is not None and out_indices is not None:
          if len(out_features) != len(out_indices):
              raise ValueError("out_features and out_indices should have the same length if both are set")
          if out_features != [stage_names[idx] for idx in out_indices]:
              raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
  64. def _align_output_features_output_indices(
  65. out_features: Optional[List[str]],
  66. out_indices: Optional[Union[List[int], Tuple[int]]],
  67. stage_names: List[str],
  68. ):
  69. """
  70. Finds the corresponding `out_features` and `out_indices` for the given `stage_names`.
  71. The logic is as follows:
  72. - `out_features` not set, `out_indices` set: `out_features` is set to the `out_features` corresponding to the
  73. `out_indices`.
  74. - `out_indices` not set, `out_features` set: `out_indices` is set to the `out_indices` corresponding to the
  75. `out_features`.
  76. - `out_indices` and `out_features` not set: `out_indices` and `out_features` are set to the last stage.
  77. - `out_indices` and `out_features` set: input `out_indices` and `out_features` are returned.
  78. Args:
  79. out_features (`List[str]`): The names of the features for the backbone to output.
  80. out_indices (`List[int]` or `Tuple[int]`): The indices of the features for the backbone to output.
  81. stage_names (`List[str]`): The names of the stages of the backbone.
  82. """
  83. if out_indices is None and out_features is None:
  84. out_indices = [len(stage_names) - 1]
  85. out_features = [stage_names[-1]]
  86. elif out_indices is None and out_features is not None:
  87. out_indices = [stage_names.index(layer) for layer in out_features]
  88. elif out_features is None and out_indices is not None:
  89. out_features = [stage_names[idx] for idx in out_indices]
  90. return out_features, out_indices
  def get_aligned_output_features_output_indices(
      out_features: Optional[List[str]],
      out_indices: Optional[Union[List[int], Tuple[int]]],
      stage_names: List[str],
  ) -> Tuple[List[str], List[int]]:
      """
      Validate `out_features` / `out_indices`, fill in whichever is missing, and return the aligned pair.
  
      Resolution rules:
          - neither given: both default to the last stage.
          - only `out_indices` given: `out_features` are derived from them.
          - only `out_features` given: `out_indices` are derived from them.
          - both given: they are checked to refer to the same stages.
  
      Args:
          out_features (`List[str]`): The names of the features for the backbone to output.
          out_indices (`List[int]` or `Tuple[int]`): The indices of the features for the backbone to output.
          stage_names (`List[str]`): The names of the stages of the backbone.
  
      Returns:
          `Tuple[List[str], List[int]]`: the aligned `(out_features, out_indices)`.
      """
      if out_indices is not None:
          # The verifier only accepts lists, so a tuple of indices is normalised up front.
          out_indices = list(out_indices)
      # Reject invalid input before trying to complete the missing half.
      verify_out_features_out_indices(out_features=out_features, out_indices=out_indices, stage_names=stage_names)
      aligned = _align_output_features_output_indices(
          out_features=out_features, out_indices=out_indices, stage_names=stage_names
      )
      aligned_features, aligned_indices = aligned
      # Re-check the completed pair (this also checks that both halves name the same stages).
      verify_out_features_out_indices(out_features=aligned_features, out_indices=aligned_indices, stage_names=stage_names)
      return aligned_features, aligned_indices
  class BackboneMixin:
      """
      Mixin for backbone models: tracks the stage names, the selected output stages (`out_features` /
      `out_indices`) and the per-stage channel counts.
      """
  
      # Filled in by `_init_backbone`; None until then.
      backbone_type: Optional[BackboneType] = None
  
      def _init_timm_backbone(self, config) -> None:
          """
          Initialise stage/output metadata from a timm model that must already be stored in `self._backbone`.
          """
          if getattr(self, "_backbone", None) is None:
              raise ValueError("self._backbone must be set before calling _init_timm_backbone")
  
          # These will disagree with the defaults for the transformers models, e.g. for resnet50
          # the transformers model has out_features = ['stem', 'stage1', 'stage2', 'stage3', 'stage4']
          # while the timm model has out_features = ['act', 'layer1', 'layer2', 'layer3', 'layer4'].
          feature_info = self._backbone.feature_info
          self.stage_names = [stage["module"] for stage in feature_info.info]
          self.num_features = [stage["num_chs"] for stage in feature_info.info]
  
          # Depending on the timm version, out_indices either mirrors the type passed to `create_model` or is
          # always a tuple (versions >= 1); normalise to a list.
          indices = list(feature_info.out_indices)
          features = feature_info.module_name()
  
          verify_out_features_out_indices(out_features=features, out_indices=indices, stage_names=self.stage_names)
          self._out_features = features
          self._out_indices = indices
  
      def _init_transformers_backbone(self, config) -> None:
          """Initialise stage/output metadata from a transformers backbone config."""
          stage_names = config.stage_names
          requested_features = getattr(config, "out_features", None)
          requested_indices = getattr(config, "out_indices", None)
  
          self.stage_names = stage_names
          self._out_features, self._out_indices = get_aligned_output_features_output_indices(
              out_features=requested_features, out_indices=requested_indices, stage_names=stage_names
          )
          # Number of channels for each stage. This is set in the transformer backbone model init.
          self.num_features = None
  
      def _init_backbone(self, config) -> None:
          """
          Initialise the backbone. Called by the constructor of the base class after the pretrained model weights
          have been loaded.
          """
          self.config = config
          self.use_timm_backbone = getattr(config, "use_timm_backbone", False)
          if self.use_timm_backbone:
              self.backbone_type = BackboneType.TIMM
          else:
              self.backbone_type = BackboneType.TRANSFORMERS
  
          initializers = {
              BackboneType.TIMM: self._init_timm_backbone,
              BackboneType.TRANSFORMERS: self._init_transformers_backbone,
          }
          initializer = initializers.get(self.backbone_type)
          if initializer is None:
              raise ValueError(f"backbone_type {self.backbone_type} not supported.")
          initializer(config)
  
      @property
      def out_features(self):
          return self._out_features
  
      @out_features.setter
      def out_features(self, out_features: List[str]):
          """Assign new `out_features`; `out_indices` is recomputed so the two stay aligned."""
          aligned = get_aligned_output_features_output_indices(
              out_features=out_features, out_indices=None, stage_names=self.stage_names
          )
          self._out_features, self._out_indices = aligned
  
      @property
      def out_indices(self):
          return self._out_indices
  
      @out_indices.setter
      def out_indices(self, out_indices: Union[Tuple[int], List[int]]):
          """Assign new `out_indices`; `out_features` is recomputed so the two stay aligned."""
          aligned = get_aligned_output_features_output_indices(
              out_features=None, out_indices=out_indices, stage_names=self.stage_names
          )
          self._out_features, self._out_indices = aligned
  
      @property
      def out_feature_channels(self):
          # Channel counts are reported for every stage, including stages not listed in out_features.
          channels_by_stage = {}
          for position, name in enumerate(self.stage_names):
              channels_by_stage[name] = self.num_features[position]
          return channels_by_stage
  
      @property
      def channels(self):
          """Channel count of each stage in `out_features`, in the same order."""
          return [self.out_feature_channels[stage] for stage in self.out_features]
  
      def forward_with_filtered_kwargs(self, *args, **kwargs):
          """Call the model, dropping any keyword argument that `forward` does not declare."""
          accepted = inspect.signature(self.forward).parameters
          kept = {name: value for name, value in kwargs.items() if name in accepted}
          return self(*args, **kept)
  
      def forward(
          self,
          pixel_values,
          output_hidden_states: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ):
          raise NotImplementedError("This method should be implemented by the derived class.")
  
      def to_dict(self):
          """
          Serializes this instance to a Python dictionary. Override the default `to_dict()` from `PretrainedConfig` to
          include the `out_features` and `out_indices` attributes.
          """
          output = super().to_dict()
          # Rename the private storage keys to their public property names.
          for public_key in ("out_features", "out_indices"):
              output[public_key] = output.pop(f"_{public_key}")
          return output
  class BackboneConfigMixin:
      """
      Mixin for backbone configurations that keeps `out_features` and `out_indices` aligned with each other and
      with `stage_names`.
      """
  
      @property
      def out_features(self):
          return self._out_features
  
      @out_features.setter
      def out_features(self, out_features: List[str]):
          """Assign new `out_features`; `out_indices` is recomputed from them."""
          aligned_features, aligned_indices = get_aligned_output_features_output_indices(
              out_features=out_features, out_indices=None, stage_names=self.stage_names
          )
          self._out_features = aligned_features
          self._out_indices = aligned_indices
  
      @property
      def out_indices(self):
          return self._out_indices
  
      @out_indices.setter
      def out_indices(self, out_indices: Union[Tuple[int], List[int]]):
          """Assign new `out_indices`; `out_features` is recomputed from them."""
          aligned_features, aligned_indices = get_aligned_output_features_output_indices(
              out_features=None, out_indices=out_indices, stage_names=self.stage_names
          )
          self._out_features = aligned_features
          self._out_indices = aligned_indices
  
      def to_dict(self):
          """
          Serializes this instance to a Python dictionary. Override the default `to_dict()` from `PretrainedConfig` to
          include the `out_features` and `out_indices` attributes.
          """
          output = super().to_dict()
          # Expose the private storage keys under their public property names.
          for public_key in ("out_features", "out_indices"):
              output[public_key] = output.pop(f"_{public_key}")
          return output
  def load_backbone(config):
      """
      Loads the backbone model from a config object.
  
      If the config is from the backbone model itself (none of `backbone_config`, `use_timm_backbone` or `backbone`
      is set on it), then we return a backbone model with randomly initialized weights.
  
      If the config is from the parent model of the backbone model, then the backbone is built from
      `config.backbone_config` or from the checkpoint named by `config.backbone`, and the pretrained backbone weights
      are loaded if specified.
  
      Args:
          config: a backbone config, or a parent-model config carrying some of `backbone_config`,
              `use_timm_backbone`, `use_pretrained_backbone`, `backbone` and `backbone_kwargs`.
  
      Returns:
          The backbone model created through `AutoBackbone`.
  
      Raises:
          ValueError: if the backbone-related attributes of `config` conflict, or are insufficient to locate a
              backbone.
      """
      backbone_config = getattr(config, "backbone_config", None)
      use_timm_backbone = getattr(config, "use_timm_backbone", None)
      use_pretrained_backbone = getattr(config, "use_pretrained_backbone", None)
      backbone_checkpoint = getattr(config, "backbone", None)
      backbone_kwargs = getattr(config, "backbone_kwargs", None)
      backbone_kwargs = {} if backbone_kwargs is None else backbone_kwargs
  
      if backbone_kwargs and backbone_config is not None:
          raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
  
      # If there is a backbone_config and a backbone checkpoint, and use_pretrained_backbone is set (True or False),
      # the desired behaviour is ill-defined: load from the checkpoint's config or from backbone_config?
      if backbone_config is not None and backbone_checkpoint is not None and use_pretrained_backbone is not None:
          raise ValueError("Cannot specify both config.backbone_config and config.backbone")
  
      # Imported inside the function (presumably to avoid a circular import at module load - confirm), and only
      # after the pure argument checks above so that invalid configs are rejected first.
      from transformers import AutoBackbone, AutoConfig
  
      # If any of the following are set, then the config passed in is from a model which contains a backbone;
      # otherwise it is treated as the backbone's own config.
      # NOTE(review): the original condition listed `backbone_checkpoint is None` twice; the duplicate is dropped
      # with no change in behaviour. `use_pretrained_backbone` is not consulted here - confirm that is intended.
      if backbone_config is None and use_timm_backbone is None and backbone_checkpoint is None:
          return AutoBackbone.from_config(config=config, **backbone_kwargs)
  
      # config from the parent model that has a backbone
      if use_timm_backbone:
          if backbone_checkpoint is None:
              raise ValueError("config.backbone must be set if use_timm_backbone is True")
          # Because of how timm backbones were originally added to models, we need to pass in use_pretrained_backbone
          # to determine whether to load the pretrained weights.
          backbone = AutoBackbone.from_pretrained(
              backbone_checkpoint,
              use_timm_backbone=use_timm_backbone,
              use_pretrained_backbone=use_pretrained_backbone,
              **backbone_kwargs,
          )
      elif use_pretrained_backbone:
          if backbone_checkpoint is None:
              raise ValueError("config.backbone must be set if use_pretrained_backbone is True")
          backbone = AutoBackbone.from_pretrained(backbone_checkpoint, **backbone_kwargs)
      else:
          if backbone_config is None and backbone_checkpoint is None:
              raise ValueError("Either config.backbone_config or config.backbone must be set")
          if backbone_config is None:
              # Only the checkpoint's config is used here, so the resulting backbone is freshly initialised.
              backbone_config = AutoConfig.from_pretrained(backbone_checkpoint, **backbone_kwargs)
          backbone = AutoBackbone.from_config(config=backbone_config)
      return backbone
  def verify_backbone_config_arguments(
      use_timm_backbone: bool,
      use_pretrained_backbone: bool,
      backbone: Optional[str],
      backbone_config: Optional[Union[dict, "PretrainedConfig"]],
      backbone_kwargs: Optional[dict],
  ):
      """
      Check that the backbone-related config arguments do not conflict with one another.
  
      Every conflict checked here involves `backbone_config`, so nothing is checked when it is None.
  
      Args:
          use_timm_backbone (`bool`): Whether a timm backbone is requested.
          use_pretrained_backbone (`bool`): Accepted for interface compatibility; not checked here.
          backbone (`str`, *optional*): The `backbone` argument.
          backbone_config (`dict` or `PretrainedConfig`, *optional*): An explicit backbone configuration.
          backbone_kwargs (`dict`, *optional*): Extra keyword arguments for the backbone.
  
      Raises:
          ValueError: if `backbone_config` is combined with a non-None `backbone`, a truthy `use_timm_backbone`,
              or a non-empty `backbone_kwargs`.
      """
      if backbone_config is None:
          return
      if backbone is not None:
          raise ValueError("You can't specify both `backbone` and `backbone_config`.")
      if use_timm_backbone:
          raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
      if backbone_kwargs:
          raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")