modeling_rope_utils.py

# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Optional, Tuple

from .configuration_utils import PretrainedConfig
from .utils import is_torch_available, logging


logger = logging.get_logger(__name__)


if is_torch_available():
    import torch


def _compute_default_rope_parameters(
    config: Optional[PretrainedConfig] = None,
    device: Optional["torch.device"] = None,
    seq_len: Optional[int] = None,
    **rope_kwargs,
) -> Tuple["torch.Tensor", float]:
    """
    Computes the inverse frequencies according to the original RoPE implementation
    Args:
        config ([`~transformers.PretrainedConfig`]):
            The model configuration.
        device (`torch.device`):
            The device to use for initialization of the inverse frequencies.
        seq_len (`int`, *optional*):
            The current sequence length. Unused for this type of RoPE.
        rope_kwargs (`Dict`, *optional*):
            BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
    Returns:
        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
    """
    if config is not None and len(rope_kwargs) > 0:
        raise ValueError(
            "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
            f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
        )
    if len(rope_kwargs) > 0:
        base = rope_kwargs["base"]
        dim = rope_kwargs["dim"]
    elif config is not None:
        base = config.rope_theta
        partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        dim = int(head_dim * partial_rotary_factor)

    attention_factor = 1.0  # Unused in this type of RoPE

    # Compute the inverse frequencies
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
    return inv_freq, attention_factor
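
# Usage sketch (illustrative; assumes `torch` is installed): via the BC keyword path only `base`
# and `dim` are needed, and the result holds `dim // 2` inverse frequencies plus the attention
# scaling factor, which is always 1.0 for this rope type.
#
#     >>> inv_freq, attention_factor = _compute_default_rope_parameters(base=10000.0, dim=64)
#     >>> inv_freq.shape, attention_factor
#     (torch.Size([32]), 1.0)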


def _compute_linear_scaling_rope_parameters(
    config: Optional[PretrainedConfig] = None,
    device: Optional["torch.device"] = None,
    seq_len: Optional[int] = None,
    **rope_kwargs,
) -> Tuple["torch.Tensor", float]:
    """
    Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev
    Args:
        config ([`~transformers.PretrainedConfig`]):
            The model configuration.
        device (`torch.device`):
            The device to use for initialization of the inverse frequencies.
        seq_len (`int`, *optional*):
            The current sequence length. Unused for this type of RoPE.
        rope_kwargs (`Dict`, *optional*):
            BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
    Returns:
        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
    """
    if config is not None and len(rope_kwargs) > 0:
        raise ValueError(
            "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
            f"`_compute_linear_scaling_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
        )
    if len(rope_kwargs) > 0:
        factor = rope_kwargs["factor"]
    elif config is not None:
        factor = config.rope_scaling["factor"]

    # Gets the default RoPE parameters
    inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)

    # Then applies linear scaling to the frequencies.
    # NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so
    # applying scaling to the inverse frequencies is equivalent.
    inv_freq /= factor
    return inv_freq, attention_factor
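
# Usage sketch (illustrative, keyword path; assumes `torch` is installed): with `factor=2.0` every
# inverse frequency is halved, i.e. positions are effectively interpolated by 2x.
#
#     >>> inv_freq, _ = _compute_linear_scaling_rope_parameters(base=10000.0, dim=64, factor=2.0)
#     >>> default_inv_freq, _ = _compute_default_rope_parameters(base=10000.0, dim=64)
#     >>> torch.allclose(inv_freq, default_inv_freq / 2.0)
#     True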


def _compute_dynamic_ntk_parameters(
    config: Optional[PretrainedConfig] = None,
    device: Optional["torch.device"] = None,
    seq_len: Optional[int] = None,
    **rope_kwargs,
) -> Tuple["torch.Tensor", float]:
    """
    Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
    Args:
        config ([`~transformers.PretrainedConfig`]):
            The model configuration.
        device (`torch.device`):
            The device to use for initialization of the inverse frequencies.
        seq_len (`int`, *optional*):
            The current sequence length, used to update the dynamic RoPE at inference time.
        rope_kwargs (`Dict`, *optional*):
            BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
    Returns:
        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
    """
    # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
    if config is not None and len(rope_kwargs) > 0:
        raise ValueError(
            "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
            f"`_compute_dynamic_ntk_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
        )
    if len(rope_kwargs) > 0:
        base = rope_kwargs["base"]
        dim = rope_kwargs["dim"]
        max_position_embeddings = rope_kwargs["max_position_embeddings"]
        factor = rope_kwargs["factor"]
    elif config is not None:
        base = config.rope_theta
        partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
        head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        dim = int(head_dim * partial_rotary_factor)
        max_position_embeddings = config.max_position_embeddings
        factor = config.rope_scaling["factor"]

    attention_factor = 1.0  # Unused in this type of RoPE

    # seq_len: default to max_position_embeddings, e.g. at init time
    seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings

    # Compute the inverse frequencies
    base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
    return inv_freq, attention_factor
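
# Usage sketch (illustrative, keyword path; assumes `torch` is installed): at or below the
# pretrained context length the parameters match default RoPE; past it, the effective `base`
# grows with `seq_len`, shrinking the inverse frequencies.
#
#     >>> kwargs = dict(base=10000.0, dim=64, max_position_embeddings=2048, factor=2.0)
#     >>> short, _ = _compute_dynamic_ntk_parameters(seq_len=1024, **kwargs)
#     >>> long, _ = _compute_dynamic_ntk_parameters(seq_len=8192, **kwargs)
#     >>> bool((long <= short).all())
#     True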


def _compute_yarn_parameters(
    config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
) -> Tuple["torch.Tensor", float]:
    """
    Computes the inverse frequencies with YaRN scaling (NTK-by-parts interpolation). Please refer to the
    [original paper](https://arxiv.org/abs/2309.00071)
    Args:
        config ([`~transformers.PretrainedConfig`]):
            The model configuration.
        device (`torch.device`):
            The device to use for initialization of the inverse frequencies.
        seq_len (`int`, *optional*):
            The current sequence length. Unused for this type of RoPE.
        rope_kwargs (`Dict`, *optional*):
            BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
    Returns:
        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
        post-processing scaling factor applied to the computed cos/sin.
    """
    # No need to keep BC with yarn, unreleased when this new pattern was created.
    if len(rope_kwargs) > 0:
        raise ValueError(
            f"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_yarn_parameters`, got {rope_kwargs}"
        )

    base = config.rope_theta
    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
    dim = int(head_dim * partial_rotary_factor)
    max_position_embeddings = config.max_position_embeddings
    factor = config.rope_scaling["factor"]

    # Sets the attention factor as suggested in the paper
    attention_factor = config.rope_scaling.get("attention_factor")
    if attention_factor is None:
        attention_factor = 0.1 * math.log(factor) + 1.0

    # Optional config options
    # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
    beta_fast = config.rope_scaling.get("beta_fast") or 32
    beta_slow = config.rope_scaling.get("beta_slow") or 1

    # Compute the inverse frequencies
    def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
        """Inverse dimension formula to find the dimension based on the number of rotations"""
        return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

    def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
        """Find dimension range bounds based on rotations"""
        low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
        high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
        return max(low, 0), min(high, dim - 1)

    def linear_ramp_factor(min, max, dim):
        if min == max:
            max += 0.001  # Prevent singularity

        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
        ramp_func = torch.clamp(linear_func, 0, 1)
        return ramp_func

    # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
    # to expand the possible context length. In other words, interpolation = apply scaling factor.
    pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim)
    inv_freq_extrapolation = 1.0 / pos_freqs
    inv_freq_interpolation = 1.0 / (factor * pos_freqs)

    low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)

    # Get n-dimensional rotational scaling corrected for extrapolation
    inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device)
    inv_freq = (
        inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
        + inv_freq_extrapolation * inv_freq_extrapolation_factor
    )
    return inv_freq, attention_factor
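
# Usage sketch (illustrative; `LlamaConfig` is used here only as a convenient carrier for the
# fields read above: `rope_theta`, `hidden_size`, `num_attention_heads`, `max_position_embeddings`
# and `rope_scaling`):
#
#     >>> from transformers import LlamaConfig
#     >>> config = LlamaConfig(rope_scaling={"rope_type": "yarn", "factor": 4.0})
#     >>> inv_freq, attention_factor = _compute_yarn_parameters(config, device="cpu")
#     >>> round(attention_factor, 4)  # 0.1 * ln(4.0) + 1.0, the default suggested in the paper
#     1.1386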


def _compute_longrope_parameters(
    config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
) -> Tuple["torch.Tensor", float]:
    """
    Computes the inverse frequencies with LongRoPE scaling. Please refer to the
    [original implementation](https://github.com/microsoft/LongRoPE)
    Args:
        config ([`~transformers.PretrainedConfig`]):
            The model configuration.
        device (`torch.device`):
            The device to use for initialization of the inverse frequencies.
        seq_len (`int`, *optional*):
            The current sequence length.
        rope_kwargs (`Dict`, *optional*):
            BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
    Returns:
        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
        post-processing scaling factor applied to the computed cos/sin.
    """
    # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
    # No need to keep BC with longrope, unreleased when this new pattern was created.
    if len(rope_kwargs) > 0:
        raise ValueError(
            "Unexpected arguments: `**rope_kwargs` should be unset in `_compute_longrope_parameters`, got "
            f"{rope_kwargs}"
        )

    base = config.rope_theta
    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
    dim = int(head_dim * partial_rotary_factor)
    long_factor = config.rope_scaling["long_factor"]
    short_factor = config.rope_scaling["short_factor"]
    factor = config.rope_scaling.get("factor")
    attention_factor = config.rope_scaling.get("attention_factor")

    # NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a
    # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
    # values to compute the default attention scaling factor, instead of using `factor`.
    if hasattr(config, "original_max_position_embeddings"):
        if seq_len and seq_len < config.original_max_position_embeddings:
            expanded_max_position_embeddings = config.original_max_position_embeddings
        else:
            expanded_max_position_embeddings = config.max_position_embeddings
        max_position_embeddings = config.original_max_position_embeddings
        factor = expanded_max_position_embeddings / max_position_embeddings
    else:
        max_position_embeddings = config.max_position_embeddings
        expanded_max_position_embeddings = max_position_embeddings * factor

    # Sets the attention factor as suggested in the paper
    if attention_factor is None:
        if factor <= 1.0:
            attention_factor = 1.0
        else:
            attention_factor = math.sqrt(1 + math.log(factor) / math.log(max_position_embeddings))

    # Compute the inverse frequencies -- scaled based on the target sequence length
    if expanded_max_position_embeddings > max_position_embeddings:
        ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
    else:
        ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device)
    inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim
    inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)

    return inv_freq, attention_factor
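
# Usage sketch (illustrative; a minimal stand-in object carrying just the fields read above).
# `short_factor`/`long_factor` must each hold `dim // 2` per-frequency multipliers, and the long
# set is selected once the expanded length exceeds the pretrained one.
#
#     >>> from types import SimpleNamespace
#     >>> config = SimpleNamespace(
#     ...     rope_theta=10000.0, hidden_size=3072, num_attention_heads=32,
#     ...     max_position_embeddings=8192, original_max_position_embeddings=4096,
#     ...     rope_scaling={"rope_type": "longrope", "short_factor": [1.0] * 48, "long_factor": [4.0] * 48},
#     ... )
#     >>> inv_freq, attention_factor = _compute_longrope_parameters(config, device="cpu")
#     >>> inv_freq.shape  # head_dim // 2 == (3072 // 32) // 2 == 48
#     torch.Size([48])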


def _compute_llama3_parameters(
    config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
) -> Tuple["torch.Tensor", float]:
    """
    Computes the inverse frequencies for llama 3.1.
    Args:
        config ([`~transformers.PretrainedConfig`]):
            The model configuration.
        device (`torch.device`):
            The device to use for initialization of the inverse frequencies.
        seq_len (`int`, *optional*):
            The current sequence length. Unused for this type of RoPE.
        rope_kwargs (`Dict`, *optional*):
            BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
    Returns:
        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
        post-processing scaling factor applied to the computed cos/sin.
    """
    # Gets the default RoPE parameters
    inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)

    factor = config.rope_scaling["factor"]  # `8` in the original implementation
    low_freq_factor = config.rope_scaling["low_freq_factor"]  # `1` in the original implementation
    high_freq_factor = config.rope_scaling["high_freq_factor"]  # `4` in the original implementation
    old_context_len = config.rope_scaling["original_max_position_embeddings"]  # `8192` in the original implementation

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    wavelen = 2 * math.pi / inv_freq
    # wavelen < high_freq_wavelen: do nothing
    # wavelen > low_freq_wavelen: divide by factor
    inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
    # otherwise: interpolate between the two, using a smooth factor
    smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
    smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
    is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
    inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)

    return inv_freq_llama, attention_factor
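
# Usage sketch (illustrative; `LlamaConfig` is only a convenient carrier for the fields read
# above): the llama3 rule leaves high frequencies untouched, divides low frequencies by `factor`,
# and smoothly interpolates in between, so scaled values never exceed the defaults.
#
#     >>> from transformers import LlamaConfig
#     >>> rope_scaling = {
#     ...     "rope_type": "llama3", "factor": 8.0, "low_freq_factor": 1.0,
#     ...     "high_freq_factor": 4.0, "original_max_position_embeddings": 8192,
#     ... }
#     >>> config = LlamaConfig(max_position_embeddings=131072, rope_scaling=rope_scaling)
#     >>> inv_freq, _ = _compute_llama3_parameters(config, device="cpu")
#     >>> default_inv_freq, _ = _compute_default_rope_parameters(config, device="cpu")
#     >>> bool((inv_freq <= default_inv_freq).all())
#     True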


# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
# from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE
# parameterizations, as long as the callable has the same signature.
ROPE_INIT_FUNCTIONS = {
    "default": _compute_default_rope_parameters,
    "linear": _compute_linear_scaling_rope_parameters,
    "dynamic": _compute_dynamic_ntk_parameters,
    "yarn": _compute_yarn_parameters,
    "longrope": _compute_longrope_parameters,
    "llama3": _compute_llama3_parameters,
}
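
# Extension sketch (the `my_rope` type is hypothetical, not a real rope_type): new
# parameterizations are enabled by registering a callable with the same signature as the
# functions above.
#
#     >>> def _compute_my_rope_parameters(config=None, device=None, seq_len=None, **rope_kwargs):
#     ...     inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)
#     ...     return inv_freq * 0.5, attention_factor
#     >>> ROPE_INIT_FUNCTIONS["my_rope"] = _compute_my_rope_parameters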


def _check_received_keys(
    rope_type: str,
    received_keys: set,
    required_keys: set,
    optional_keys: Optional[set] = None,
    ignore_keys: Optional[set] = None,
):
    """Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
    # BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present
    if "type" in received_keys:
        received_keys -= {"type"}
        required_keys.add("rope_type")

    # Some models need to store model-specific keys, and we don't want to throw warnings at them
    if ignore_keys is not None:
        received_keys -= ignore_keys

    missing_keys = required_keys - received_keys
    if missing_keys:
        raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}")

    if optional_keys is not None:
        unused_keys = received_keys - required_keys - optional_keys
    else:
        unused_keys = received_keys - required_keys
    if unused_keys:
        logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
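
# Behavior sketch (illustrative values): missing required keys raise, unknown keys only warn, and
# keys listed in `ignore_keys` are skipped entirely.
#
#     >>> _check_received_keys("linear", {"rope_type", "factor", "extra"}, {"rope_type", "factor"})
#     ... # logs a warning about the unrecognized key 'extra'
#     >>> _check_received_keys("linear", {"rope_type"}, {"rope_type", "factor"})
#     ... # raises KeyError: 'factor' is required for 'rope_type'='linear'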


def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
    rope_scaling = config.rope_scaling
    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
    required_keys = {"rope_type"}
    received_keys = set(rope_scaling.keys())
    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)


def _validate_linear_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
    rope_scaling = config.rope_scaling
    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
    required_keys = {"rope_type", "factor"}
    received_keys = set(rope_scaling.keys())
    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)

    factor = rope_scaling["factor"]
    if factor is None or not isinstance(factor, float) or factor < 1.0:
        logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")


def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
    rope_scaling = config.rope_scaling
    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
    required_keys = {"rope_type", "factor"}
    # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
    optional_keys = {"original_max_position_embeddings"}
    received_keys = set(rope_scaling.keys())
    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)

    factor = rope_scaling["factor"]
    if factor is None or not isinstance(factor, float) or factor < 1.0:
        logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")


def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
    rope_scaling = config.rope_scaling
    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
    required_keys = {"rope_type", "factor"}
    optional_keys = {"attention_factor", "beta_fast", "beta_slow"}
    received_keys = set(rope_scaling.keys())
    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)

    factor = rope_scaling["factor"]
    if factor is None or not isinstance(factor, float) or factor < 1.0:
        logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")

    attention_factor = rope_scaling.get("attention_factor")
    if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):
        logger.warning(
            f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
        )
    beta_fast = rope_scaling.get("beta_fast")
    if beta_fast is not None and not isinstance(beta_fast, float):
        logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
    beta_slow = rope_scaling.get("beta_slow")
    if beta_slow is not None and not isinstance(beta_slow, float):
        logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")

    if (beta_fast or 32) < (beta_slow or 1):
        logger.warning(
            f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} "
            f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
        )


def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
    rope_scaling = config.rope_scaling
    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
    required_keys = {"rope_type", "short_factor", "long_factor"}
    # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
    optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
    received_keys = set(rope_scaling.keys())
    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)

    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
    dim = int(head_dim * partial_rotary_factor)

    short_factor = rope_scaling.get("short_factor")
    if not (isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor)):
        logger.warning(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}")
    if not len(short_factor) == dim // 2:
        logger.warning(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}")

    long_factor = rope_scaling.get("long_factor")
    if not (isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor)):
        logger.warning(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}")
    if not len(long_factor) == dim // 2:
        logger.warning(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}")

    # Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over
    # `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is
    # unique to longrope (= undesirable)
    if hasattr(config, "original_max_position_embeddings"):
        logger.warning_once(
            "This model has set a `original_max_position_embeddings` field, to be used together with "
            "`max_position_embeddings` to determine a scaling factor. Please set the `factor` field of `rope_scaling` "
            "with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, "
            "as it is compatible with most model architectures."
        )
    else:
        factor = rope_scaling.get("factor")
        if factor is None:
            logger.warning("Missing required keys in `rope_scaling`: 'factor'")
        elif not isinstance(factor, float) or factor < 1.0:
            logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")

        attention_factor = rope_scaling.get("attention_factor")
        if attention_factor is not None:
            if not isinstance(attention_factor, float) or attention_factor < 0.0:
                logger.warning(
                    f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
                )


def _validate_llama3_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
    rope_scaling = config.rope_scaling
    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
    required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
    received_keys = set(rope_scaling.keys())
    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)

    factor = rope_scaling["factor"]
    if factor is None or not isinstance(factor, float) or factor < 1.0:
        logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")

    low_freq_factor = rope_scaling["low_freq_factor"]
    high_freq_factor = rope_scaling["high_freq_factor"]
    if low_freq_factor is None or not isinstance(low_freq_factor, float):
        logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
    if high_freq_factor is None or not isinstance(high_freq_factor, float):
        logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
    if high_freq_factor <= low_freq_factor:
        logger.warning(
            "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
            f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
        )

    original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
    if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int):
        logger.warning(
            "`rope_scaling`'s original_max_position_embeddings field must be an integer, got "
            f"{original_max_position_embeddings}"
        )
    if original_max_position_embeddings >= config.max_position_embeddings:
        logger.warning(
            "`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got "
            f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}"
        )


# Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types.
ROPE_VALIDATION_FUNCTIONS = {
    "default": _validate_default_rope_parameters,
    "linear": _validate_linear_scaling_rope_parameters,
    "dynamic": _validate_dynamic_scaling_rope_parameters,
    "yarn": _validate_yarn_parameters,
    "longrope": _validate_longrope_parameters,
    "llama3": _validate_llama3_parameters,
}
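
# Extension sketch (same hypothetical `my_rope` type as above): a matching validator can be
# registered so that `rope_config_validation` below checks configs using the new type.
#
#     >>> def _validate_my_rope_parameters(config, ignore_keys=None):
#     ...     received_keys = set(config.rope_scaling.keys())
#     ...     _check_received_keys("my_rope", received_keys, {"rope_type"}, ignore_keys=ignore_keys)
#     >>> ROPE_VALIDATION_FUNCTIONS["my_rope"] = _validate_my_rope_parameters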


def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None):
    """
    Validate the RoPE config arguments, given a `PretrainedConfig` object
    """
    rope_scaling = getattr(config, "rope_scaling", None)  # not a default parameter in `PretrainedConfig`
    if rope_scaling is None:
        return

    # BC: "rope_type" was originally "type"
    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
    validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
    if validation_fn is not None:
        validation_fn(config, ignore_keys=ignore_keys)
    else:
        logger.warning(
            f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
        )
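
# Usage sketch (illustrative; `LlamaConfig` is one of the configs that calls this at init time):
# the function is a no-op when the config has no `rope_scaling`, and otherwise dispatches on the
# (BC-normalized) `rope_type`.
#
#     >>> from transformers import LlamaConfig
#     >>> config = LlamaConfig(rope_scaling={"rope_type": "linear", "factor": 2.0})
#     >>> rope_config_validation(config)  # passes silently
#     >>> config.rope_scaling["factor"] = 0.5
#     >>> rope_config_validation(config)  # logs a warning: factor must be a float >= 1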