# coding=utf-8
# Copyright 2022 the Big Science Workshop and HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Bloom configuration"""
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, List, Mapping, Optional

from packaging import version


if TYPE_CHECKING:
    from ... import PreTrainedTokenizer, TensorType

from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfigWithPast, PatchingSpec
from ...utils import is_torch_available, logging


logger = logging.get_logger(__name__)


class BloomConfig(PretrainedConfig):
- """
- This is the configuration class to store the configuration of a [`BloomModel`]. It is used to instantiate a Bloom
- model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
- defaults will yield a similar configuration to the Bloom architecture
- [bigscience/bloom](https://huggingface.co/bigscience/bloom).
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
- documentation from [`PretrainedConfig`] for more information.
- Args:
        vocab_size (`int`, *optional*, defaults to 250880):
            Vocabulary size of the Bloom model. Defines the maximum number of different tokens that can be represented
            by the `inputs_ids` passed when calling [`BloomModel`]. Check [this
            discussion](https://huggingface.co/bigscience/bloom/discussions/120#633d28389addb8530b406c2a) on how the
            `vocab_size` has been defined.
        hidden_size (`int`, *optional*, defaults to 64):
            Dimensionality of the embeddings and hidden states.
        n_layer (`int`, *optional*, defaults to 2):
            Number of hidden layers in the Transformer decoder.
        n_head (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer decoder.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
            The epsilon to use in the layer normalization layers.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        apply_residual_connection_post_layernorm (`bool`, *optional*, defaults to `False`):
            If enabled, use the layer norm of the hidden states as the residual in the transformer blocks.
        hidden_dropout (`float`, *optional*, defaults to 0.0):
            Dropout rate applied to the hidden states in the bias-dropout-add operations.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout rate applied to the attention probabilities.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        bos_token_id (`int`, *optional*, defaults to 1):
            The id of the beginning-of-sequence token.
        eos_token_id (`int`, *optional*, defaults to 2):
            The id of the end-of-sequence token.
        pretraining_tp (`int`, *optional*, defaults to 1):
            Experimental feature. Tensor parallelism rank used during pretraining with Megatron. Please refer to [this
            document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
            necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
            issue](https://github.com/pytorch/pytorch/issues/76232). Note also that this is enabled only when
            `slow_but_exact=True`.
        slow_but_exact (`bool`, *optional*, defaults to `False`):
            Experimental feature. Whether to use a slow but exact implementation of the attention mechanism. While
            merging the TP rank tensors, due to slicing operations the results may be slightly different between the
            model trained on Megatron and our model. Please refer to [this
            issue](https://github.com/pytorch/pytorch/issues/76232). A solution to obtain more accurate results is to
            enable this feature. Enabling it will slow down inference. This will probably be resolved in the future
            once the main model has been fine-tuned with TP_rank=1.

    Example:

    ```python
    >>> from transformers import BloomConfig, BloomModel

    >>> # Initializing a Bloom configuration
    >>> configuration = BloomConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = BloomModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""
- model_type = "bloom"
- keys_to_ignore_at_inference = ["past_key_values"]
- attribute_map = {
- "num_hidden_layers": "n_layer",
- "num_attention_heads": "n_head",
- }

    def __init__(
        self,
        vocab_size=250880,
        hidden_size=64,
        n_layer=2,
        n_head=8,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        use_cache=True,
        bos_token_id=1,
        eos_token_id=2,
        apply_residual_connection_post_layernorm=False,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        pretraining_tp=1,  # TP rank used when training with megatron
        slow_but_exact=False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        # Backward compatibility with n_embed kwarg
        n_embed = kwargs.pop("n_embed", None)
        self.hidden_size = hidden_size if n_embed is None else n_embed
        self.n_layer = n_layer
        self.n_head = n_head
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.use_cache = use_cache
        self.pretraining_tp = pretraining_tp
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.slow_but_exact = slow_but_exact

        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)


class BloomOnnxConfig(OnnxConfigWithPast):
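    """
    ONNX export configuration for Bloom, built on [`OnnxConfigWithPast`] so that past key/values can be exported as
    model inputs when `use_past=True`.
    """
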
    torch_onnx_minimum_version = version.parse("1.12")

    def __init__(
        self,
        config: PretrainedConfig,
        task: str = "default",
        patching_specs: Optional[List[PatchingSpec]] = None,
        use_past: bool = False,
    ):
        super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)
        if not getattr(self._config, "pad_token_id", None):
            # TODO: how to do that better?
            self._config.pad_token_id = 0

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
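        """
        Names of the ONNX model inputs mapped to their dynamic axes. When `use_past` is enabled, the past key/value
        tensors are added as inputs and the attention mask covers both the past and the current sequence.
        """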
        common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
        if self.use_past:
            # BLOOM stores values on dynamic axis 2. For more details see: https://github.com/huggingface/transformers/pull/18344
            self.fill_with_past_key_values_(common_inputs, direction="inputs", inverted_values_shape=True)
            common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"}
        else:
            common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}

        return common_inputs

    @property
    def num_layers(self) -> int:
        return self._config.n_layer

    @property
    def num_attention_heads(self) -> int:
        return self._config.n_head

    @property
    def atol_for_validation(self) -> float:
        return 1e-3

    def generate_dummy_inputs(
        self,
        tokenizer: "PreTrainedTokenizer",
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional["TensorType"] = None,
    ) -> Mapping[str, Any]:
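        """
        Build dummy inputs for ONNX export. When `use_past` is enabled, zero-filled past key/value tensors are added
        (keys shaped `(batch * num_heads, head_dim, past_length)`, values shaped
        `(batch * num_heads, past_length, head_dim)`) and the attention mask is extended to cover the past length.
        """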
        common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
            tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
        )

        # We need to order the inputs in the way they appear in the forward()
        ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})

        # Need to add the past_keys
        if self.use_past:
            if not is_torch_available():
                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
            else:
                import torch

                batch, seqlen = common_inputs["input_ids"].shape
                # Not using the same length for past_key_values
                past_key_values_length = seqlen + 2
                head_dim = self._config.hidden_size // self.num_attention_heads
                # BLOOM keeps the head dimension before the past sequence length for the keys ...
                past_key_shape = (
                    batch * self.num_attention_heads,
                    head_dim,
                    past_key_values_length,
                )
                # ... and after it for the values
                past_value_shape = (
                    batch * self.num_attention_heads,
                    past_key_values_length,
                    head_dim,
                )
                ordered_inputs["past_key_values"] = [
                    (torch.zeros(past_key_shape), torch.zeros(past_value_shape)) for _ in range(self.num_layers)
                ]

        ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
        if self.use_past:
            mask_dtype = ordered_inputs["attention_mask"].dtype
            # Extend the attention mask so it also covers the dummy past positions
            ordered_inputs["attention_mask"] = torch.cat(
                [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
            )

        return ordered_inputs

    @property
    def default_onnx_opset(self) -> int:
        return 13
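

# Usage sketch (illustrative): how `BloomOnnxConfig` can be used to build dummy ONNX export inputs.
# The checkpoint name below is only an example; any Bloom checkpoint with a tokenizer should work.
#
#     from transformers import AutoTokenizer, BloomConfig, TensorType
#     from transformers.models.bloom.configuration_bloom import BloomOnnxConfig
#
#     config = BloomConfig()
#     onnx_config = BloomOnnxConfig(config, task="default", use_past=True)
#     tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
#     dummy_inputs = onnx_config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
#     print(list(dummy_inputs.keys()))  # input_ids, past_key_values, attention_mask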