modeling_flax_pytorch_utils.py

# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch - Flax general utilities."""

import os
from pickle import UnpicklingError
from typing import Dict, Tuple

import jax
import jax.numpy as jnp
import numpy as np
from flax.serialization import from_bytes
from flax.traverse_util import flatten_dict, unflatten_dict

import transformers

from . import is_safetensors_available, is_torch_available
from .utils import logging


if is_torch_available():
    import torch

if is_safetensors_available():
    from safetensors import safe_open
    from safetensors.flax import load_file as safe_load_file


logger = logging.get_logger(__name__)


#####################
# PyTorch => Flax #
#####################


def load_pytorch_checkpoint_in_flax_state_dict(
    flax_model, pytorch_checkpoint_path, is_sharded, allow_missing_keys=False
):
    """Load PyTorch checkpoints into a Flax model."""

    if not is_sharded:
        pt_path = os.path.abspath(pytorch_checkpoint_path)
        logger.info(f"Loading PyTorch weights from {pt_path}")

        if pt_path.endswith(".safetensors"):
            pt_state_dict = {}
            with safe_open(pt_path, framework="flax") as f:
                for k in f.keys():
                    pt_state_dict[k] = f.get_tensor(k)
        else:
            try:
                import torch  # noqa: F401

                from .pytorch_utils import is_torch_greater_or_equal_than_1_13  # noqa: F401
            except (ImportError, ModuleNotFoundError):
                logger.error(
                    "Loading a PyTorch model in Flax requires both PyTorch and Flax to be installed. Please see"
                    " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for"
                    " installation instructions."
                )
                raise

            weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
            pt_state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg)
            logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")

        flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
    else:
        # model is sharded and pytorch_checkpoint_path already contains the list of .pt shard files
        flax_state_dict = convert_pytorch_sharded_state_dict_to_flax(pytorch_checkpoint_path, flax_model)

    return flax_state_dict
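
# Illustrative usage sketch (not part of the module itself; the model class and checkpoint
# path are assumptions, and the checkpoint is assumed to match the architecture): given a
# randomly initialized Flax model and a local PyTorch checkpoint, the converted params can
# be assigned back onto the model.
#
#     from transformers import BertConfig, FlaxBertModel
#
#     flax_model = FlaxBertModel(BertConfig())
#     flax_params = load_pytorch_checkpoint_in_flax_state_dict(
#         flax_model, "pytorch_model.bin", is_sharded=False
#     )
#     flax_model.params = flax_params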


def rename_key_and_reshape_tensor(
    pt_tuple_key: Tuple[str],
    pt_tensor: np.ndarray,
    random_flax_state_dict: Dict[str, jnp.ndarray],
    model_prefix: str,
) -> Tuple[Tuple[str], np.ndarray]:
    """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary."""

    def is_key_or_prefix_key_in_dict(key: Tuple[str]) -> bool:
        """Checks if `key` or `(model_prefix,) + key` is in `random_flax_state_dict`."""
        return len(set(random_flax_state_dict) & {key, (model_prefix,) + key}) > 0

    # layer norm
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
    if pt_tuple_key[-1] in ["weight", "gamma"] and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key):
        return renamed_pt_tuple_key, pt_tensor

    # batch norm layer mean
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("mean",)
    if pt_tuple_key[-1] == "running_mean" and not is_key_or_prefix_key_in_dict(pt_tuple_key):
        return renamed_pt_tuple_key, pt_tensor

    # batch norm layer var
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("var",)
    if pt_tuple_key[-1] == "running_var" and not is_key_or_prefix_key_in_dict(pt_tuple_key):
        return renamed_pt_tuple_key, pt_tensor

    # embedding
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
    if pt_tuple_key[-1] == "weight" and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key):
        return renamed_pt_tuple_key, pt_tensor

    # conv layer
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
    if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4 and not is_key_or_prefix_key_in_dict(pt_tuple_key):
        pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
        return renamed_pt_tuple_key, pt_tensor

    # linear layer
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
    if pt_tuple_key[-1] == "weight" and not is_key_or_prefix_key_in_dict(pt_tuple_key):
        pt_tensor = pt_tensor.T
        return renamed_pt_tuple_key, pt_tensor

    # old PyTorch layer norm weight
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
    if pt_tuple_key[-1] == "gamma":
        return renamed_pt_tuple_key, pt_tensor

    # old PyTorch layer norm bias
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
    if pt_tuple_key[-1] == "beta":
        return renamed_pt_tuple_key, pt_tensor

    # New `weight_norm` from https://github.com/huggingface/transformers/pull/24030
    name = None
    if pt_tuple_key[-3::2] == ("parametrizations", "original0"):
        name = pt_tuple_key[-2] + "_g"
    elif pt_tuple_key[-3::2] == ("parametrizations", "original1"):
        name = pt_tuple_key[-2] + "_v"
    if name is not None:
        renamed_pt_tuple_key = pt_tuple_key[:-3] + (name,)
        return renamed_pt_tuple_key, pt_tensor

    return pt_tuple_key, pt_tensor
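
# Worked example (illustrative; the key and shapes are assumptions, not from a specific
# model): a PyTorch linear weight stored as (out_features, in_features) is renamed to
# "kernel" and transposed into the (in_features, out_features) layout Flax expects.
#
#     pt_key = ("dense", "weight")
#     pt_tensor = np.zeros((8, 4), dtype=np.float32)          # (out_features, in_features)
#     flax_params = {("dense", "kernel"): jnp.zeros((4, 8))}
#     new_key, new_tensor = rename_key_and_reshape_tensor(pt_key, pt_tensor, flax_params, "model")
#     # new_key == ("dense", "kernel"), new_tensor.shape == (4, 8)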


def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
    # convert pytorch tensor to numpy
    from_bin = is_torch_available() and isinstance(next(iter(pt_state_dict.values())), torch.Tensor)
    bfloat16 = torch.bfloat16 if from_bin else "bfloat16"

    weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()}

    if from_bin:
        for k, v in pt_state_dict.items():
            # numpy currently does not support bfloat16; go through float32 in this case so no precision is lost
            if v.dtype == bfloat16:
                v = v.float()
            pt_state_dict[k] = v.cpu().numpy()

    model_prefix = flax_model.base_model_prefix

    # use params dict if the model contains batch norm layers
    if "params" in flax_model.params:
        flax_model_params = flax_model.params["params"]
    else:
        flax_model_params = flax_model.params
    random_flax_state_dict = flatten_dict(flax_model_params)

    # add batch_stats keys/values to dict
    if "batch_stats" in flax_model.params:
        flax_batch_stats = flatten_dict(flax_model.params["batch_stats"])
        random_flax_state_dict.update(flax_batch_stats)

    flax_state_dict = {}

    load_model_with_head_into_base_model = (model_prefix not in flax_model_params) and (
        model_prefix in {k.split(".")[0] for k in pt_state_dict.keys()}
    )
    load_base_model_into_model_with_head = (model_prefix in flax_model_params) and (
        model_prefix not in {k.split(".")[0] for k in pt_state_dict.keys()}
    )

    # Need to change some parameter names to match Flax names
    for pt_key, pt_tensor in pt_state_dict.items():
        pt_tuple_key = tuple(pt_key.split("."))
        is_bfloat_16 = weight_dtypes[pt_key] == bfloat16

        # remove base model prefix if necessary
        has_base_model_prefix = pt_tuple_key[0] == model_prefix
        if load_model_with_head_into_base_model and has_base_model_prefix:
            pt_tuple_key = pt_tuple_key[1:]

        # Correctly rename weight parameters
        flax_key, flax_tensor = rename_key_and_reshape_tensor(
            pt_tuple_key, pt_tensor, random_flax_state_dict, model_prefix
        )

        # add model prefix if necessary
        require_base_model_prefix = (model_prefix,) + flax_key in random_flax_state_dict
        if load_base_model_into_model_with_head and require_base_model_prefix:
            flax_key = (model_prefix,) + flax_key

        if flax_key in random_flax_state_dict:
            if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
                raise ValueError(
                    f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
                    f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
                )

        # add batch stats if the model contains batchnorm layers
        if "batch_stats" in flax_model.params:
            if "mean" in flax_key[-1] or "var" in flax_key[-1]:
                flax_state_dict[("batch_stats",) + flax_key] = jnp.asarray(flax_tensor)
                continue
            # remove num_batches_tracked key
            if "num_batches_tracked" in flax_key[-1]:
                flax_state_dict.pop(flax_key, None)
                continue

            # also add unexpected weight so that warning is thrown
            flax_state_dict[("params",) + flax_key] = (
                jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
            )
        else:
            # also add unexpected weight so that warning is thrown
            flax_state_dict[flax_key] = (
                jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
            )

    return unflatten_dict(flax_state_dict)
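
# Sketch of the resulting structure (assumed toy example): the returned value is a nested
# dict keyed like `flax_model.params`, so a flat PyTorch key such as
# "encoder.layer.0.output.dense.weight" ends up at
# ["encoder"]["layer"]["0"]["output"]["dense"]["kernel"] with the array transposed.
# For models with batch norm layers, the result is nested one level deeper under the
# top-level "params" and "batch_stats" keys.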


############################
# Sharded PyTorch => Flax #
############################


def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
    import torch

    from .pytorch_utils import is_torch_greater_or_equal_than_1_13

    # Convert each shard in turn and merge the results into a single Flax state dict
    flax_state_dict = {}
    for shard_file in shard_filenames:
        # load the shard with torch.load
        weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
        pt_state_dict = torch.load(shard_file, **weights_only_kwarg)
        weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()}
        pt_state_dict = {
            k: v.numpy() if v.dtype != torch.bfloat16 else v.float().numpy() for k, v in pt_state_dict.items()
        }

        model_prefix = flax_model.base_model_prefix

        # use params dict if the model contains batch norm layers and then add batch_stats keys/values to dict
        if "batch_stats" in flax_model.params:
            flax_model_params = flax_model.params["params"]

            random_flax_state_dict = flatten_dict(flax_model_params)
            random_flax_state_dict.update(flatten_dict(flax_model.params["batch_stats"]))
        else:
            flax_model_params = flax_model.params
            random_flax_state_dict = flatten_dict(flax_model_params)

        load_model_with_head_into_base_model = (model_prefix not in flax_model_params) and (
            model_prefix in {k.split(".")[0] for k in pt_state_dict.keys()}
        )
        load_base_model_into_model_with_head = (model_prefix in flax_model_params) and (
            model_prefix not in {k.split(".")[0] for k in pt_state_dict.keys()}
        )

        # Need to change some parameter names to match Flax names
        for pt_key, pt_tensor in pt_state_dict.items():
            pt_tuple_key = tuple(pt_key.split("."))
            is_bfloat_16 = weight_dtypes[pt_key] == torch.bfloat16

            # remove base model prefix if necessary
            has_base_model_prefix = pt_tuple_key[0] == model_prefix
            if load_model_with_head_into_base_model and has_base_model_prefix:
                pt_tuple_key = pt_tuple_key[1:]

            # Correctly rename weight parameters
            flax_key, flax_tensor = rename_key_and_reshape_tensor(
                pt_tuple_key, pt_tensor, random_flax_state_dict, model_prefix
            )

            # add model prefix if necessary
            require_base_model_prefix = (model_prefix,) + flax_key in random_flax_state_dict
            if load_base_model_into_model_with_head and require_base_model_prefix:
                flax_key = (model_prefix,) + flax_key

            if flax_key in random_flax_state_dict:
                if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
                    raise ValueError(
                        f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
                        f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
                    )

            # add batch stats if the model contains batchnorm layers
            if "batch_stats" in flax_model.params:
                if "mean" in flax_key[-1]:
                    flax_state_dict[("batch_stats",) + flax_key] = jnp.asarray(flax_tensor)
                    continue
                if "var" in flax_key[-1]:
                    flax_state_dict[("batch_stats",) + flax_key] = jnp.asarray(flax_tensor)
                    continue
                # remove num_batches_tracked key
                if "num_batches_tracked" in flax_key[-1]:
                    flax_state_dict.pop(flax_key, None)
                    continue

                # also add unexpected weight so that warning is thrown
                flax_state_dict[("params",) + flax_key] = (
                    jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
                )
            else:
                # also add unexpected weight so that warning is thrown
                flax_state_dict[flax_key] = (
                    jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
                )

    return unflatten_dict(flax_state_dict)
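
# Illustrative call for a sharded checkpoint (the file names are assumptions): the caller
# passes the already-resolved list of shard files, not the index json itself.
#
#     shard_files = ["pytorch_model-00001-of-00002.bin", "pytorch_model-00002-of-00002.bin"]
#     flax_params = convert_pytorch_sharded_state_dict_to_flax(shard_files, flax_model)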


#####################
# Flax => PyTorch #
#####################


def load_flax_checkpoint_in_pytorch_model(model, flax_checkpoint_path):
    """Load Flax checkpoints into a PyTorch model."""
    flax_checkpoint_path = os.path.abspath(flax_checkpoint_path)
    logger.info(f"Loading Flax weights from {flax_checkpoint_path}")

    # import correct flax class
    flax_cls = getattr(transformers, "Flax" + model.__class__.__name__)

    # load flax weight dict
    if flax_checkpoint_path.endswith(".safetensors"):
        flax_state_dict = safe_load_file(flax_checkpoint_path)
        flax_state_dict = unflatten_dict(flax_state_dict, sep=".")
    else:
        with open(flax_checkpoint_path, "rb") as state_f:
            try:
                flax_state_dict = from_bytes(flax_cls, state_f.read())
            except UnpicklingError:
                raise EnvironmentError(f"Unable to convert {flax_checkpoint_path} to Flax deserializable object.")

    return load_flax_weights_in_pytorch_model(model, flax_state_dict)
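
# Illustrative usage sketch (the model class and file path are assumptions): load Flax
# weights stored in msgpack or safetensors format into a randomly initialized PyTorch model.
#
#     from transformers import BertConfig, BertModel
#
#     pt_model = BertModel(BertConfig())
#     pt_model = load_flax_checkpoint_in_pytorch_model(pt_model, "flax_model.msgpack")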


def load_flax_weights_in_pytorch_model(pt_model, flax_state):
    """Load Flax weights into a PyTorch model."""

    try:
        import torch  # noqa: F401
    except (ImportError, ModuleNotFoundError):
        logger.error(
            "Loading Flax weights in PyTorch requires both PyTorch and Flax to be installed. Please see"
            " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
            " instructions."
        )
        raise

    # check if we have bf16 weights
    is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values()
    if any(is_type_bf16):
        # convert all weights to fp32 if they are bf16, since torch.from_numpy cannot handle bf16
        # and bf16 is not fully supported in PT yet.
        logger.warning(
            "Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` "
            "before loading those in PyTorch model."
        )
        flax_state = jax.tree_util.tree_map(
            lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state
        )

    flax_state_dict = flatten_dict(flax_state)
    pt_model_dict = pt_model.state_dict()

    load_model_with_head_into_base_model = (pt_model.base_model_prefix in flax_state) and (
        pt_model.base_model_prefix not in {k.split(".")[0] for k in pt_model_dict.keys()}
    )
    load_base_model_into_model_with_head = (pt_model.base_model_prefix not in flax_state) and (
        pt_model.base_model_prefix in {k.split(".")[0] for k in pt_model_dict.keys()}
    )

    # keep track of unexpected & missing keys
    unexpected_keys = []
    missing_keys = set(pt_model_dict.keys())

    for flax_key_tuple, flax_tensor in flax_state_dict.items():
        has_base_model_prefix = flax_key_tuple[0] == pt_model.base_model_prefix
        require_base_model_prefix = ".".join((pt_model.base_model_prefix,) + flax_key_tuple) in pt_model_dict

        # adapt flax_key to prepare for loading from/to base model only
        if load_model_with_head_into_base_model and has_base_model_prefix:
            flax_key_tuple = flax_key_tuple[1:]
        elif load_base_model_into_model_with_head and require_base_model_prefix:
            flax_key_tuple = (pt_model.base_model_prefix,) + flax_key_tuple

        # rename flax weights to PyTorch format
        if flax_key_tuple[-1] == "kernel" and flax_tensor.ndim == 4 and ".".join(flax_key_tuple) not in pt_model_dict:
            # conv layer
            flax_key_tuple = flax_key_tuple[:-1] + ("weight",)
            flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1))
        elif flax_key_tuple[-1] == "kernel" and ".".join(flax_key_tuple) not in pt_model_dict:
            # linear layer
            flax_key_tuple = flax_key_tuple[:-1] + ("weight",)
            flax_tensor = flax_tensor.T
        elif flax_key_tuple[-1] in ["scale", "embedding"]:
            flax_key_tuple = flax_key_tuple[:-1] + ("weight",)

        # adding batch stats from flax batch norm to pt
        elif "mean" in flax_key_tuple[-1]:
            flax_key_tuple = flax_key_tuple[:-1] + ("running_mean",)
        elif "var" in flax_key_tuple[-1]:
            flax_key_tuple = flax_key_tuple[:-1] + ("running_var",)

        if "batch_stats" in flax_state:
            flax_key = ".".join(flax_key_tuple[1:])  # Remove the params/batch_stats header
        else:
            flax_key = ".".join(flax_key_tuple)

        # We also need to look at `pt_model_dict` and see if there are keys requiring further transformation.
        special_pt_names = {}
        # New `weight_norm` from https://github.com/huggingface/transformers/pull/24030
        for key in pt_model_dict:
            key_components = key.split(".")
            name = None
            if key_components[-3::2] == ["parametrizations", "original0"]:
                name = key_components[-2] + "_g"
            elif key_components[-3::2] == ["parametrizations", "original1"]:
                name = key_components[-2] + "_v"
            if name is not None:
                key_components = key_components[:-3] + [name]
                key_to_check = ".".join(key_components)
                special_pt_names[key_to_check] = key

        if flax_key in special_pt_names:
            flax_key = special_pt_names[flax_key]

        if flax_key in pt_model_dict:
            if flax_tensor.shape != pt_model_dict[flax_key].shape:
                raise ValueError(
                    f"Flax checkpoint seems to be incorrect. Weight {flax_key_tuple} was expected "
                    f"to be of shape {pt_model_dict[flax_key].shape}, but is {flax_tensor.shape}."
                )
            else:
                # add weight to pytorch dict
                flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor
                pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)
                # remove from missing keys
                missing_keys.remove(flax_key)
        else:
            # weight is not expected by PyTorch model
            unexpected_keys.append(flax_key)

    pt_model.load_state_dict(pt_model_dict)

    # re-transform missing_keys to list
    missing_keys = list(missing_keys)

    if len(unexpected_keys) > 0:
        logger.warning(
            "Some weights of the Flax model were not used when initializing the PyTorch model"
            f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
            f" {pt_model.__class__.__name__} from a Flax model trained on another task or with another architecture"
            " (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n- This"
            f" IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect"
            " to be exactly identical (e.g. initializing a BertForSequenceClassification model from a"
            " FlaxBertForSequenceClassification model)."
        )
    else:
        logger.warning(f"All Flax model weights were used when initializing {pt_model.__class__.__name__}.\n")

    if len(missing_keys) > 0:
        logger.warning(
            f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly"
            f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to"
            " use it for predictions and inference."
        )
    else:
        logger.warning(
            f"All the weights of {pt_model.__class__.__name__} were initialized from the Flax model.\n"
            "If your task is similar to the task the model of the checkpoint was trained on, "
            f"you can already use {pt_model.__class__.__name__} for predictions without further training."
        )

    return pt_model
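
# Minimal in-memory sketch (the classes below are assumptions for illustration): when the
# Flax params are already available as a nested dict, they can be loaded into a PyTorch
# model directly, without going through a checkpoint file on disk.
#
#     from transformers import BertConfig, BertModel, FlaxBertModel
#
#     config = BertConfig()
#     flax_model = FlaxBertModel(config)
#     pt_model = BertModel(config)
#     pt_model = load_flax_weights_in_pytorch_model(pt_model, flax_model.params)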