configuration_olmo.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. # coding=utf-8
  2. # Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  5. # and OPT implementations in this library. It has been modified from its
  6. # original forms to accommodate minor architectural differences compared
  7. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. """OLMo model configuration"""
  21. from ...configuration_utils import PretrainedConfig
  22. from ...utils import logging
# Module-level logger obtained from the library's `logging` utilities (the `...utils` import above).
# NOTE(review): `logger` is not referenced anywhere in the code shown in this file; presumably kept for
# parity with sibling configuration modules — confirm before removing.
logger = logging.get_logger(__name__)
  24. class OlmoConfig(PretrainedConfig):
  25. r"""
  26. This is the configuration class to store the configuration of a [`OlmoModel`]. It is used to instantiate an OLMo
  27. model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
  28. defaults will yield a similar configuration to that of the [allenai/OLMo-7B-hf](https://huggingface.co/allenai/OLMo-7B-hf).
  29. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  30. documentation from [`PretrainedConfig`] for more information.
  31. Args:
  32. vocab_size (`int`, *optional*, defaults to 50304):
  33. Vocabulary size of the OLMo model. Defines the number of different tokens that can be represented by the
  34. `inputs_ids` passed when calling [`OlmoModel`]
  35. hidden_size (`int`, *optional*, defaults to 4096):
  36. Dimension of the hidden representations.
  37. intermediate_size (`int`, *optional*, defaults to 11008):
  38. Dimension of the MLP representations.
  39. num_hidden_layers (`int`, *optional*, defaults to 32):
  40. Number of hidden layers in the Transformer decoder.
  41. num_attention_heads (`int`, *optional*, defaults to 32):
  42. Number of attention heads for each attention layer in the Transformer decoder.
  43. num_key_value_heads (`int`, *optional*):
  44. This is the number of key_value heads that should be used to implement Grouped Query Attention. If
  45. `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
  46. `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
  47. converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
  48. by meanpooling all the original heads within that group. For more details checkout [this
  49. paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
  50. `num_attention_heads`.
  51. hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
  52. The non-linear activation function (function or string) in the decoder.
  53. max_position_embeddings (`int`, *optional*, defaults to 2048):
  54. The maximum sequence length that this model might ever be used with.
  55. initializer_range (`float`, *optional*, defaults to 0.02):
  56. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  57. use_cache (`bool`, *optional*, defaults to `True`):
  58. Whether or not the model should return the last key/values attentions (not used by all models). Only
  59. relevant if `config.is_decoder=True`.
  60. pad_token_id (`int`, *optional*, defaults to 1):
  61. Padding token id.
  62. bos_token_id (`int`, *optional*):
  63. Beginning of stream token id.
  64. eos_token_id (`int`, *optional*, defaults to 50279):
  65. End of stream token id.
  66. tie_word_embeddings (`bool`, *optional*, defaults to `False`):
  67. Whether to tie weight embeddings
  68. rope_theta (`float`, *optional*, defaults to 10000.0):
  69. The base period of the RoPE embeddings.
  70. rope_scaling (`Dict`, *optional*):
  71. Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
  72. strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
  73. `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
  74. `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
  75. these scaling strategies behave:
  76. https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
  77. experimental feature, subject to breaking API changes in future versions.
  78. attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
  79. Whether to use a bias in the query, key, value and output projection layers during self-attention.
  80. attention_dropout (`float`, *optional*, defaults to 0.0):
  81. The dropout ratio for the attention probabilities.
  82. clip_qkv (`float`, *optional*):
  83. If not `None`, elements of query, key and value attention states are clipped so that their
  84. absolute value does not exceed this value.
  85. ```python
  86. >>> from transformers import OlmoModel, OlmoConfig
  87. >>> # Initializing a OLMo 7B style configuration
  88. >>> configuration = OlmoConfig()
  89. >>> # Initializing a model from the OLMo 7B style configuration
  90. >>> model = OlmoModel(configuration)
  91. >>> # Accessing the model configuration
  92. >>> configuration = model.config
  93. ```"""
  94. model_type = "olmo"
  95. keys_to_ignore_at_inference = ["past_key_values"]
  96. def __init__(
  97. self,
  98. vocab_size=50304,
  99. hidden_size=4096,
  100. intermediate_size=11008,
  101. num_hidden_layers=32,
  102. num_attention_heads=32,
  103. num_key_value_heads=None,
  104. hidden_act="silu",
  105. max_position_embeddings=2048,
  106. initializer_range=0.02,
  107. use_cache=True,
  108. pad_token_id=1,
  109. bos_token_id=None,
  110. eos_token_id=50279,
  111. tie_word_embeddings=False,
  112. rope_theta=10000.0,
  113. rope_scaling=None,
  114. attention_bias=False,
  115. attention_dropout=0.0,
  116. clip_qkv=None,
  117. **kwargs,
  118. ):
  119. self.vocab_size = vocab_size
  120. self.max_position_embeddings = max_position_embeddings
  121. self.hidden_size = hidden_size
  122. self.intermediate_size = intermediate_size
  123. self.num_hidden_layers = num_hidden_layers
  124. self.num_attention_heads = num_attention_heads
  125. # for backward compatibility
  126. if num_key_value_heads is None:
  127. num_key_value_heads = num_attention_heads
  128. self.num_key_value_heads = num_key_value_heads
  129. self.hidden_act = hidden_act
  130. self.initializer_range = initializer_range
  131. self.use_cache = use_cache
  132. self.rope_theta = rope_theta
  133. self.rope_scaling = rope_scaling
  134. self._rope_scaling_validation()
  135. self.attention_bias = attention_bias
  136. self.attention_dropout = attention_dropout
  137. self.clip_qkv = clip_qkv
  138. super().__init__(
  139. pad_token_id=pad_token_id,
  140. bos_token_id=bos_token_id,
  141. eos_token_id=eos_token_id,
  142. tie_word_embeddings=tie_word_embeddings,
  143. **kwargs,
  144. )
  145. def _rope_scaling_validation(self):
  146. """
  147. Validate the `rope_scaling` configuration.
  148. """
  149. if self.rope_scaling is None:
  150. return
  151. if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
  152. raise ValueError(
  153. "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
  154. )
  155. rope_scaling_type = self.rope_scaling.get("type", None)
  156. rope_scaling_factor = self.rope_scaling.get("factor", None)
  157. if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
  158. raise ValueError(
  159. f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
  160. )
  161. if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
  162. raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")