modeling_tf_pytorch_utils.py

# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch - TF 2.0 general utilities."""

import os
import re

import numpy

from .utils import (
    ExplicitEnum,
    expand_dims,
    is_numpy_array,
    is_safetensors_available,
    is_torch_tensor,
    logging,
    reshape,
    squeeze,
    tensor_size,
)
from .utils import transpose as transpose_func


if is_safetensors_available():
    from safetensors import safe_open


logger = logging.get_logger(__name__)


class TransposeType(ExplicitEnum):
    """
    Possible ways a weight may need to be transposed when converting between TF 2.0 and PyTorch layouts.
    """

    NO = "no"
    SIMPLE = "simple"
    CONV1D = "conv1d"
    CONV2D = "conv2d"


def convert_tf_weight_name_to_pt_weight_name(
    tf_name, start_prefix_to_remove="", tf_weight_shape=None, name_scope=None
):
    """
    Convert a TF 2.0 model variable name to a PyTorch model weight name.

    Conventions for TF 2.0 scopes -> PyTorch attribute names conversions:

        - '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF 2.0 vs PyTorch)
        - '_._' is replaced by a new level separation (can be used to convert TF 2.0 lists to PyTorch nn.ModuleList)

    Returns a tuple with:

        - PyTorch model weight name
        - transpose: `TransposeType` member indicating whether and how the TF 2.0 and PyTorch weight matrices should
          be transposed with regard to each other
    """
    if name_scope is not None:
        if not tf_name.startswith(name_scope) and "final_logits_bias" not in tf_name:
            raise ValueError(
                f"Weight name {tf_name} does not start with name_scope {name_scope}. This is an internal error "
                "in Transformers, so (unless you were doing something really evil) please open an issue to report it!"
            )
        tf_name = tf_name[len(name_scope) :]
        tf_name = tf_name.lstrip("/")
    tf_name = tf_name.replace(":0", "")  # device ids
    tf_name = re.sub(
        r"/[^/]*___([^/]*)/", r"/\1/", tf_name
    )  # '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF 2.0 vs PyTorch)
    tf_name = tf_name.replace(
        "_._", "/"
    )  # '_._' is replaced by a level separation (can be used to convert TF 2.0 lists to PyTorch nn.ModuleList)
    tf_name = re.sub(r"//+", "/", tf_name)  # Remove empty levels
    tf_name = tf_name.split("/")  # Convert from TF 2.0 '/' separators to PyTorch '.' separators
    # Some weights have a single name without "/" such as final_logits_bias in BART
    if len(tf_name) > 1:
        tf_name = tf_name[1:]  # Remove level zero
    if tf_weight_shape is not None:
        # Guard against the default None, which list() would reject
        tf_weight_shape = list(tf_weight_shape)

    # When should we transpose the weights
    if tf_name[-1] == "kernel" and tf_weight_shape is not None and len(tf_weight_shape) == 4:
        transpose = TransposeType.CONV2D
    elif tf_name[-1] == "kernel" and tf_weight_shape is not None and len(tf_weight_shape) == 3:
        transpose = TransposeType.CONV1D
    elif bool(
        tf_name[-1] in ["kernel", "pointwise_kernel", "depthwise_kernel"]
        or "emb_projs" in tf_name
        or "out_projs" in tf_name
    ):
        transpose = TransposeType.SIMPLE
    else:
        transpose = TransposeType.NO

    # Convert standard TF 2.0 names to PyTorch names
    if tf_name[-1] == "kernel" or tf_name[-1] == "embeddings" or tf_name[-1] == "gamma":
        tf_name[-1] = "weight"
    if tf_name[-1] == "beta":
        tf_name[-1] = "bias"

    # The SeparableConv1D TF layer contains two weights that are translated to PyTorch Conv1D here
    if tf_name[-1] == "pointwise_kernel" or tf_name[-1] == "depthwise_kernel":
        tf_name[-1] = tf_name[-1].replace("_kernel", ".weight")

    # Remove prefix if needed
    tf_name = ".".join(tf_name)
    if start_prefix_to_remove:
        tf_name = tf_name.replace(start_prefix_to_remove, "", 1)

    return tf_name, transpose
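

# Illustrative sketch, not part of the library: the TF variable name and shape below are
# hypothetical, chosen to exercise the conversion rules (a 2-D `kernel` maps to a transposed
# PyTorch `weight`, `_._` becomes a ModuleList index separator, and level zero is removed).
#
#   convert_tf_weight_name_to_pt_weight_name(
#       "tf_bert_model/bert/encoder/layer_._0/attention/self/query/kernel:0",
#       tf_weight_shape=(768, 768),
#   )
#   # -> ("bert.encoder.layer.0.attention.self.query.weight", TransposeType.SIMPLE)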


def apply_transpose(transpose: TransposeType, weight, match_shape=None, pt_to_tf=True):
    """
    Apply a transpose to a weight, then try to reshape it to the given `match_shape`, all in a framework-agnostic
    way.
    """
    if transpose is TransposeType.CONV2D:
        # Conv2D weight:
        #    PT: (num_out_channel, num_in_channel, kernel[0], kernel[1])
        # -> TF: (kernel[0], kernel[1], num_in_channel, num_out_channel)
        axes = (2, 3, 1, 0) if pt_to_tf else (3, 2, 0, 1)
        weight = transpose_func(weight, axes=axes)
    elif transpose is TransposeType.CONV1D:
        # Conv1D weight:
        #    PT: (num_out_channel, num_in_channel, kernel)
        # -> TF: (kernel, num_in_channel, num_out_channel)
        weight = transpose_func(weight, axes=(2, 1, 0))
    elif transpose is TransposeType.SIMPLE:
        weight = transpose_func(weight)

    if match_shape is None:
        return weight

    if len(match_shape) < len(weight.shape):
        weight = squeeze(weight)
    elif len(match_shape) > len(weight.shape):
        weight = expand_dims(weight, axis=0)

    if list(match_shape) != list(weight.shape):
        try:
            weight = reshape(weight, match_shape)
        except AssertionError as e:
            # Attach both the expected and the actual shape to the error for easier debugging
            e.args += (match_shape, weight.shape)
            raise e

    return weight
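

# Illustrative sketch, not part of the library: a PyTorch Conv2D weight of (hypothetical) shape
# (num_out_channel=8, num_in_channel=3, 5, 5) is permuted with axes (2, 3, 1, 0) into the TF
# layout (5, 5, 3, 8); with pt_to_tf=False the inverse permutation (3, 2, 0, 1) is applied.
#
#   pt_weight = numpy.zeros((8, 3, 5, 5))
#   tf_weight = apply_transpose(TransposeType.CONV2D, pt_weight, match_shape=(5, 5, 3, 8))
#   assert tf_weight.shape == (5, 5, 3, 8)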


#####################
# PyTorch => TF 2.0 #
#####################


def load_pytorch_checkpoint_in_tf2_model(
    tf_model,
    pytorch_checkpoint_path,
    tf_inputs=None,
    allow_missing_keys=False,
    output_loading_info=False,
    _prefix=None,
    tf_to_pt_weight_rename=None,
):
    """Load PyTorch checkpoints in a TF 2.0 model"""
    try:
        import tensorflow as tf  # noqa: F401
        import torch  # noqa: F401
        from safetensors.torch import load_file as safe_load_file  # noqa: F401

        from .pytorch_utils import is_torch_greater_or_equal_than_1_13  # noqa: F401
    except ImportError:
        logger.error(
            "Loading a PyTorch model in TensorFlow requires both PyTorch and TensorFlow to be installed. Please see "
            "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
        )
        raise

    # Treat a single file as a collection of shards with 1 shard.
    if isinstance(pytorch_checkpoint_path, str):
        pytorch_checkpoint_path = [pytorch_checkpoint_path]

    # Load all shards into a single state dictionary
    pt_state_dict = {}
    for path in pytorch_checkpoint_path:
        pt_path = os.path.abspath(path)
        logger.info(f"Loading PyTorch weights from {pt_path}")
        if pt_path.endswith(".safetensors"):
            state_dict = safe_load_file(pt_path)
        else:
            weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
            state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg)
        pt_state_dict.update(state_dict)

    logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters")

    return load_pytorch_weights_in_tf2_model(
        tf_model,
        pt_state_dict,
        tf_inputs=tf_inputs,
        allow_missing_keys=allow_missing_keys,
        output_loading_info=output_loading_info,
        _prefix=_prefix,
        tf_to_pt_weight_rename=tf_to_pt_weight_rename,
    )
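

# Illustrative usage sketch, not part of the library; the model class is real but the local
# checkpoint path below is hypothetical:
#
#   from transformers import BertConfig, TFBertModel
#
#   tf_model = TFBertModel(BertConfig())
#   tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, "./pytorch_model.bin")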


def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=None, allow_missing_keys=False):
    """Load an instantiated PyTorch model in a TF 2.0 model"""
    pt_state_dict = pt_model.state_dict()

    return load_pytorch_weights_in_tf2_model(
        tf_model, pt_state_dict, tf_inputs=tf_inputs, allow_missing_keys=allow_missing_keys
    )


def load_pytorch_weights_in_tf2_model(
    tf_model,
    pt_state_dict,
    tf_inputs=None,
    allow_missing_keys=False,
    output_loading_info=False,
    _prefix=None,
    tf_to_pt_weight_rename=None,
):
    """Load a PyTorch state_dict in a TF 2.0 model."""
    try:
        import tensorflow as tf  # noqa: F401
        import torch  # noqa: F401
    except ImportError:
        logger.error(
            "Loading a PyTorch model in TensorFlow requires both PyTorch and TensorFlow to be installed. Please see "
            "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
        )
        raise

    # NumPy doesn't understand bfloat16, so upcast to a dtype that doesn't lose precision
    pt_state_dict = {
        k: v.numpy() if v.dtype != torch.bfloat16 else v.float().numpy() for k, v in pt_state_dict.items()
    }
    return load_pytorch_state_dict_in_tf2_model(
        tf_model,
        pt_state_dict,
        tf_inputs=tf_inputs,
        allow_missing_keys=allow_missing_keys,
        output_loading_info=output_loading_info,
        _prefix=_prefix,
        tf_to_pt_weight_rename=tf_to_pt_weight_rename,
    )
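

# Illustrative note on the bfloat16 handling above: calling .numpy() directly on a bfloat16
# tensor raises a TypeError, so the dict comprehension upcasts to float32 first.
#
#   torch.zeros(2, dtype=torch.bfloat16).numpy()          # raises TypeError
#   torch.zeros(2, dtype=torch.bfloat16).float().numpy()  # float32 ndarray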


def _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name):
    if len(unexpected_keys) > 0:
        logger.warning(
            "Some weights of the PyTorch model were not used when initializing the TF 2.0 model"
            f" {class_name}: {unexpected_keys}\n- This IS expected if you are initializing"
            f" {class_name} from a PyTorch model trained on another task or with another architecture"
            " (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).\n- This IS"
            f" NOT expected if you are initializing {class_name} from a PyTorch model that you expect"
            " to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a"
            " BertForSequenceClassification model)."
        )
    else:
        logger.warning(f"All PyTorch model weights were used when initializing {class_name}.\n")
    if len(missing_keys) > 0:
        logger.warning(
            f"Some weights or buffers of the TF 2.0 model {class_name} were not initialized from the"
            f" PyTorch model and are newly initialized: {missing_keys}\nYou should probably TRAIN this model on a"
            " down-stream task to be able to use it for predictions and inference."
        )
    else:
        logger.warning(
            f"All the weights of {class_name} were initialized from the PyTorch model.\n"
            "If your task is similar to the task the model of the checkpoint was trained on, "
            f"you can already use {class_name} for predictions without further training."
        )

    if len(mismatched_keys) > 0:
        mismatched_warning = "\n".join(
            [
                f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                for key, shape1, shape2 in mismatched_keys
            ]
        )
        logger.warning(
            f"Some weights of {class_name} were not initialized from the model checkpoint and are newly initialized"
            " because the shapes did not"
            f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
            " to use it for predictions and inference."
        )


def load_pytorch_state_dict_in_tf2_model(
    tf_model,
    pt_state_dict,
    tf_inputs=None,
    allow_missing_keys=False,
    output_loading_info=False,
    _prefix=None,
    tf_to_pt_weight_rename=None,
    ignore_mismatched_sizes=False,
    skip_logger_warnings=False,
):
    """Load a PyTorch state_dict in a TF 2.0 model. pt_state_dict can be either an actual dict or a lazy-loading
    safetensors archive created with the safe_open() function."""
    import tensorflow as tf

    if tf_inputs is None:
        tf_inputs = tf_model.dummy_inputs

    if _prefix is None:
        _prefix = ""
    if tf_inputs:
        with tf.name_scope(_prefix):
            tf_model(tf_inputs, training=False)  # Make sure model is built
    # Convert old format to new format if needed from a PyTorch state_dict
    tf_keys_to_pt_keys = {}
    for key in pt_state_dict.keys():
        new_key = None
        if "gamma" in key:
            new_key = key.replace("gamma", "weight")
        if "beta" in key:
            new_key = key.replace("beta", "bias")
        if "running_var" in key:
            new_key = key.replace("running_var", "moving_variance")
        if "running_mean" in key:
            new_key = key.replace("running_mean", "moving_mean")

        # New `weight_norm` from https://github.com/huggingface/transformers/pull/24030
        key_components = key.split(".")
        name = None
        if key_components[-3::2] == ["parametrizations", "original0"]:
            name = key_components[-2] + "_g"
        elif key_components[-3::2] == ["parametrizations", "original1"]:
            name = key_components[-2] + "_v"
        if name is not None:
            key_components = key_components[:-3] + [name]
            new_key = ".".join(key_components)

        if new_key is None:
            new_key = key
        tf_keys_to_pt_keys[new_key] = key
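
    # Illustrative example of the weight_norm remapping above (the key is hypothetical):
    #   "conv.parametrizations.weight.original0" -> "conv.weight_g"
    #   "conv.parametrizations.weight.original1" -> "conv.weight_v"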

    # Matt: All TF models store the actual model stem in a MainLayer class, including the base model.
    # In PT, the derived models (with heads) use the base model class as the stem instead,
    # and there is no MainLayer class. This means that TF base classes have one
    # extra layer in their weight names, corresponding to the MainLayer class. This code block compensates for that.
    start_prefix_to_remove = ""
    if not any(s.startswith(tf_model.base_model_prefix) for s in tf_keys_to_pt_keys.keys()):
        start_prefix_to_remove = tf_model.base_model_prefix + "."

    symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights
    tf_loaded_numel = 0
    all_pytorch_weights = set(tf_keys_to_pt_keys.keys())
    missing_keys = []
    mismatched_keys = []
    is_safetensor_archive = hasattr(pt_state_dict, "get_tensor")
    for symbolic_weight in symbolic_weights:
        sw_name = symbolic_weight.name
        name, transpose = convert_tf_weight_name_to_pt_weight_name(
            sw_name,
            start_prefix_to_remove=start_prefix_to_remove,
            tf_weight_shape=symbolic_weight.shape,
            name_scope=_prefix,
        )
        if tf_to_pt_weight_rename is not None:
            aliases = tf_to_pt_weight_rename(name)  # Is a tuple to account for possible name aliasing
            for alias in aliases:  # The aliases are in priority order, take the first one that matches
                if alias in tf_keys_to_pt_keys:
                    name = alias
                    break
            else:
                # If none of the aliases match, just use the first one (it'll be reported as missing)
                name = aliases[0]

        # Find associated numpy array in pytorch model state dict
        if name not in tf_keys_to_pt_keys:
            if allow_missing_keys:
                missing_keys.append(name)
                continue
            elif tf_model._keys_to_ignore_on_load_missing is not None:
                # authorized missing keys don't have to be loaded
                if any(re.search(pat, name) is not None for pat in tf_model._keys_to_ignore_on_load_missing):
                    continue
            raise AttributeError(f"{name} not found in PyTorch model")
        state_dict_name = tf_keys_to_pt_keys[name]
        if is_safetensor_archive:
            array = pt_state_dict.get_tensor(state_dict_name)
        else:
            array = pt_state_dict[state_dict_name]
        try:
            array = apply_transpose(transpose, array, symbolic_weight.shape)
        except tf.errors.InvalidArgumentError as e:
            if not ignore_mismatched_sizes:
                error_msg = str(e)
                error_msg += (
                    "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
                )
                raise tf.errors.InvalidArgumentError(error_msg)
            else:
                mismatched_keys.append((name, array.shape, symbolic_weight.shape))
                continue

        tf_loaded_numel += tensor_size(array)

        symbolic_weight.assign(tf.cast(array, symbolic_weight.dtype))
        del array  # Immediately free memory to keep peak usage as low as possible
        all_pytorch_weights.discard(name)

    logger.info(f"Loaded {tf_loaded_numel:,} parameters in the TF 2.0 model.")

    unexpected_keys = list(all_pytorch_weights)

    if tf_model._keys_to_ignore_on_load_missing is not None:
        for pat in tf_model._keys_to_ignore_on_load_missing:
            missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
    if tf_model._keys_to_ignore_on_load_unexpected is not None:
        for pat in tf_model._keys_to_ignore_on_load_unexpected:
            unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
    if not skip_logger_warnings:
        _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name=tf_model.__class__.__name__)

    if output_loading_info:
        loading_info = {
            "missing_keys": missing_keys,
            "unexpected_keys": unexpected_keys,
            "mismatched_keys": mismatched_keys,
        }
        return tf_model, loading_info

    return tf_model


def load_sharded_pytorch_safetensors_in_tf2_model(
    tf_model,
    safetensors_shards,
    tf_inputs=None,
    allow_missing_keys=False,
    output_loading_info=False,
    _prefix=None,
    tf_to_pt_weight_rename=None,
    ignore_mismatched_sizes=False,
):
    """Load sharded PyTorch safetensors checkpoints in a TF 2.0 model, merging the per-shard loading info."""
    all_loading_infos = []
    for shard in safetensors_shards:
        with safe_open(shard, framework="tf") as safetensors_archive:
            tf_model, loading_info = load_pytorch_state_dict_in_tf2_model(
                tf_model,
                safetensors_archive,
                tf_inputs=tf_inputs,
                allow_missing_keys=allow_missing_keys,
                output_loading_info=True,
                _prefix=_prefix,
                tf_to_pt_weight_rename=tf_to_pt_weight_rename,
                ignore_mismatched_sizes=ignore_mismatched_sizes,
                skip_logger_warnings=True,  # We will emit merged warnings at the end
            )
        all_loading_infos.append(loading_info)
    # Now we just need to merge the loading info
    # Keys are missing only if they're missing in *every* shard
    missing_keys = sorted(set.intersection(*[set(info["missing_keys"]) for info in all_loading_infos]))
    # Keys are unexpected/mismatched if they're unexpected/mismatched in *any* shard
    unexpected_keys = sum([info["unexpected_keys"] for info in all_loading_infos], [])
    mismatched_keys = sum([info["mismatched_keys"] for info in all_loading_infos], [])

    _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name=tf_model.__class__.__name__)

    if output_loading_info:
        loading_info = {
            "missing_keys": missing_keys,
            "unexpected_keys": unexpected_keys,
            "mismatched_keys": mismatched_keys,
        }
        return tf_model, loading_info

    return tf_model
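

# Illustrative example of the merge semantics above, with hypothetical per-shard loading infos:
# a key is only "missing" if no shard provided it (set intersection), while unexpected and
# mismatched keys from any shard are all reported (list concatenation).
#
#   infos = [{"missing_keys": ["a", "b"]}, {"missing_keys": ["b"]}]
#   sorted(set.intersection(*[set(i["missing_keys"]) for i in infos]))  # -> ["b"]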


#####################
# TF 2.0 => PyTorch #
#####################


def load_tf2_checkpoint_in_pytorch_model(
    pt_model, tf_checkpoint_path, tf_inputs=None, allow_missing_keys=False, output_loading_info=False
):
    """
    Load a TF 2.0 HDF5 checkpoint in a PyTorch model. We use HDF5 to easily do transfer learning (see
    https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
    """
    try:
        import tensorflow as tf  # noqa: F401
        import torch  # noqa: F401
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch requires both PyTorch and TensorFlow to be installed. Please see "
            "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
        )
        raise

    import transformers

    from .modeling_tf_utils import load_tf_weights

    logger.info(f"Loading TensorFlow weights from {tf_checkpoint_path}")

    # Instantiate and load the associated TF 2.0 model
    tf_model_class_name = "TF" + pt_model.__class__.__name__  # Add "TF" at the beginning
    tf_model_class = getattr(transformers, tf_model_class_name)
    tf_model = tf_model_class(pt_model.config)

    if tf_inputs is None:
        tf_inputs = tf_model.dummy_inputs

    if tf_inputs is not None:
        tf_model(tf_inputs, training=False)  # Make sure model is built

    load_tf_weights(tf_model, tf_checkpoint_path)

    return load_tf2_model_in_pytorch_model(
        pt_model, tf_model, allow_missing_keys=allow_missing_keys, output_loading_info=output_loading_info
    )


def load_tf2_model_in_pytorch_model(pt_model, tf_model, allow_missing_keys=False, output_loading_info=False):
    """Load a TF 2.0 model in a PyTorch model"""
    weights = tf_model.weights

    return load_tf2_weights_in_pytorch_model(
        pt_model, weights, allow_missing_keys=allow_missing_keys, output_loading_info=output_loading_info
    )


def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=False, output_loading_info=False):
    """Load TF 2.0 symbolic weights in a PyTorch model"""
    try:
        import tensorflow as tf  # noqa: F401
        import torch  # noqa: F401
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch requires both PyTorch and TensorFlow to be installed. Please see "
            "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
        )
        raise

    tf_state_dict = {tf_weight.name: tf_weight.numpy() for tf_weight in tf_weights}

    return load_tf2_state_dict_in_pytorch_model(
        pt_model, tf_state_dict, allow_missing_keys=allow_missing_keys, output_loading_info=output_loading_info
    )


def load_tf2_state_dict_in_pytorch_model(pt_model, tf_state_dict, allow_missing_keys=False, output_loading_info=False):
    """Load a TF 2.0 state_dict in a PyTorch model"""
    import torch

    new_pt_params_dict = {}
    current_pt_params_dict = dict(pt_model.named_parameters())

    # Make sure we are able to load PyTorch base models as well as derived models (with heads)
    # TF models always have a prefix, some of the PyTorch models (base ones) don't
    start_prefix_to_remove = ""
    if not any(s.startswith(pt_model.base_model_prefix) for s in current_pt_params_dict.keys()):
        start_prefix_to_remove = pt_model.base_model_prefix + "."

    # Build a map from potential PyTorch weight names to TF 2.0 Variables
    tf_weights_map = {}
    for name, tf_weight in tf_state_dict.items():
        pt_name, transpose = convert_tf_weight_name_to_pt_weight_name(
            name, start_prefix_to_remove=start_prefix_to_remove, tf_weight_shape=tf_weight.shape
        )
        tf_weights_map[pt_name] = (tf_weight, transpose)

    all_tf_weights = set(tf_weights_map.keys())
    loaded_pt_weights_data_ptr = {}
    missing_keys_pt = []
    for pt_weight_name, pt_weight in current_pt_params_dict.items():
        # Handle PyTorch shared weights (they are not duplicated in TF 2.0)
        if pt_weight.data_ptr() in loaded_pt_weights_data_ptr:
            new_pt_params_dict[pt_weight_name] = loaded_pt_weights_data_ptr[pt_weight.data_ptr()]
            continue

        pt_weight_name_to_check = pt_weight_name
        # New `weight_norm` from https://github.com/huggingface/transformers/pull/24030
        key_components = pt_weight_name.split(".")
        name = None
        if key_components[-3::2] == ["parametrizations", "original0"]:
            name = key_components[-2] + "_g"
        elif key_components[-3::2] == ["parametrizations", "original1"]:
            name = key_components[-2] + "_v"
        if name is not None:
            key_components = key_components[:-3] + [name]
            pt_weight_name_to_check = ".".join(key_components)

        # Find the associated TF 2.0 weight in the weights map built above
        if pt_weight_name_to_check not in tf_weights_map:
            if allow_missing_keys:
                missing_keys_pt.append(pt_weight_name)
                continue

            raise AttributeError(f"{pt_weight_name} not found in TF 2.0 model")

        array, transpose = tf_weights_map[pt_weight_name_to_check]

        array = apply_transpose(transpose, array, pt_weight.shape, pt_to_tf=False)

        if numpy.isscalar(array):
            array = numpy.array(array)
        if not is_torch_tensor(array) and not is_numpy_array(array):
            array = array.numpy()
        if is_numpy_array(array):
            # Convert to torch tensor
            array = torch.from_numpy(array)

        new_pt_params_dict[pt_weight_name] = array
        loaded_pt_weights_data_ptr[pt_weight.data_ptr()] = array
        all_tf_weights.discard(pt_weight_name)

    missing_keys, unexpected_keys = pt_model.load_state_dict(new_pt_params_dict, strict=False)
    missing_keys += missing_keys_pt

    # Some models may have keys that are not in the state by design, removing them before needlessly warning
    # the user.
    if pt_model._keys_to_ignore_on_load_missing is not None:
        for pat in pt_model._keys_to_ignore_on_load_missing:
            missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

    if pt_model._keys_to_ignore_on_load_unexpected is not None:
        for pat in pt_model._keys_to_ignore_on_load_unexpected:
            unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

    if len(unexpected_keys) > 0:
        logger.warning(
            "Some weights of the TF 2.0 model were not used when initializing the PyTorch model"
            f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
            f" {pt_model.__class__.__name__} from a TF 2.0 model trained on another task or with another architecture"
            " (e.g. initializing a BertForSequenceClassification model from a TFBertForPreTraining model).\n- This IS"
            f" NOT expected if you are initializing {pt_model.__class__.__name__} from a TF 2.0 model that you expect"
            " to be exactly identical (e.g. initializing a BertForSequenceClassification model from a"
            " TFBertForSequenceClassification model)."
        )
    else:
        logger.warning(f"All TF 2.0 model weights were used when initializing {pt_model.__class__.__name__}.\n")
    if len(missing_keys) > 0:
        logger.warning(
            f"Some weights of {pt_model.__class__.__name__} were not initialized from the TF 2.0 model and are newly"
            f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to"
            " use it for predictions and inference."
        )
    else:
        logger.warning(
            f"All the weights of {pt_model.__class__.__name__} were initialized from the TF 2.0 model.\n"
            "If your task is similar to the task the model of the checkpoint was trained on, "
            f"you can already use {pt_model.__class__.__name__} for predictions without further training."
        )

    logger.info(f"Weights or buffers not loaded from TF 2.0 model: {all_tf_weights}")

    if output_loading_info:
        loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys}
        return pt_model, loading_info

    return pt_model
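

# Illustrative usage sketch, not part of the library; the config and local checkpoint path are
# hypothetical. Note that load_tf2_checkpoint_in_pytorch_model resolves the TF class by
# prepending "TF" to the PyTorch class name, so the PyTorch class must have a TF counterpart
# (here BertModel -> TFBertModel).
#
#   from transformers import BertConfig, BertModel
#
#   pt_model = BertModel(BertConfig())
#   pt_model = load_tf2_checkpoint_in_pytorch_model(pt_model, "./tf_model.h5")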