# coding=utf-8
# Copyright 2022, The LongT5 Authors and HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LongT5 model configuration"""

from typing import Mapping

from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxSeq2SeqConfigWithPast
from ...utils import logging


logger = logging.get_logger(__name__)


class LongT5Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`LongT5Model`] or a [`FlaxLongT5Model`]. It is
    used to instantiate a LongT5 model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the LongT5
    [google/long-t5-local-base](https://huggingface.co/google/long-t5-local-base) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    Arguments:
        vocab_size (`int`, *optional*, defaults to 32128):
            Vocabulary size of the LongT5 model. Defines the number of different tokens that can be represented by
            the `input_ids` passed when calling [`LongT5Model`].
        d_model (`int`, *optional*, defaults to 512):
            Size of the encoder layers and the pooler layer.
        d_kv (`int`, *optional*, defaults to 64):
            Size of the key, query, value projections per attention head. `d_kv` has to be equal to `d_model //
            num_heads`.
        d_ff (`int`, *optional*, defaults to 2048):
            Size of the intermediate feed forward layer in each `LongT5Block`.
        num_layers (`int`, *optional*, defaults to 6):
            Number of hidden layers in the Transformer encoder.
        num_decoder_layers (`int`, *optional*):
            Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set.
        num_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer encoder.
        local_radius (`int`, *optional*, defaults to 127):
            Number of tokens to the left/right for each token to locally self-attend in a local attention mechanism.
        global_block_size (`int`, *optional*, defaults to 16):
            Length of blocks an input sequence is divided into for a global token representation. Used only for
            `encoder_attention_type = "transient-global"`.
        relative_attention_num_buckets (`int`, *optional*, defaults to 32):
            The number of buckets to use for each attention layer.
        relative_attention_max_distance (`int`, *optional*, defaults to 128):
            The maximum distance of the longer sequences for the bucket separation.
        dropout_rate (`float`, *optional*, defaults to 0.1):
            The ratio for all dropout layers.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-6):
            The epsilon used by the layer normalization layers.
        initializer_factor (`float`, *optional*, defaults to 1):
            A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
            testing).
        feed_forward_proj (`string`, *optional*, defaults to `"relu"`):
            Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. LongT5v1.1 and the
            original LongT5 implementation both use the `"gated-gelu"` feed forward projection.
        encoder_attention_type (`string`, *optional*, defaults to `"local"`):
            Type of encoder attention to be used. Should be one of `"local"` or `"transient-global"`, which are
            supported by the LongT5 implementation.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
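
    Example (an illustrative usage sketch; only classes and parameters documented above are used):

    ```python
    >>> from transformers import LongT5Config, LongT5Model

    >>> # Default configuration, similar to google/long-t5-local-base (local encoder attention)
    >>> configuration = LongT5Config()

    >>> # A configuration using transient-global encoder attention instead
    >>> tglobal_configuration = LongT5Config(encoder_attention_type="transient-global", global_block_size=16)

    >>> # Initializing a (randomly weighted) model from the configuration
    >>> model = LongT5Model(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```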
- """
- model_type = "longt5"
- keys_to_ignore_at_inference = ["past_key_values"]
- attribute_map = {
- "hidden_size": "d_model",
- "num_attention_heads": "num_heads",
- "num_hidden_layers": "num_layers",
- "head_dim": "d_kv",
- }

    def __init__(
        self,
        vocab_size=32128,
        d_model=512,
        d_kv=64,
        d_ff=2048,
        num_layers=6,
        num_decoder_layers=None,
        num_heads=8,
        local_radius=127,
        global_block_size=16,
        relative_attention_num_buckets=32,
        relative_attention_max_distance=128,
        dropout_rate=0.1,
        layer_norm_epsilon=1e-6,
        initializer_factor=1.0,
        feed_forward_proj="relu",
        is_encoder_decoder=True,
        encoder_attention_type="local",
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.d_kv = d_kv
        self.d_ff = d_ff
        self.num_layers = num_layers
        # default = symmetry (decoder gets the same number of layers as the encoder)
        self.num_decoder_layers = num_decoder_layers if num_decoder_layers is not None else self.num_layers
        self.num_heads = num_heads
        self.local_radius = local_radius
        self.global_block_size = global_block_size
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.relative_attention_max_distance = relative_attention_max_distance
        self.dropout_rate = dropout_rate
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_factor = initializer_factor
        self.feed_forward_proj = feed_forward_proj
        self.encoder_attention_type = encoder_attention_type
        self.use_cache = use_cache
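
        # Parse `feed_forward_proj` into an activation name and a gating flag:
        # "relu"       -> dense_act_fn="relu", is_gated_act=False
        # "gated-gelu" -> dense_act_fn="gelu", is_gated_act=True
        #                 (remapped to "gelu_new" below for backwards compatibility)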
        act_info = self.feed_forward_proj.split("-")
        self.dense_act_fn = act_info[-1]
        self.is_gated_act = act_info[0] == "gated"

        if (len(act_info) > 1 and act_info[0] != "gated") or len(act_info) > 2:
            raise ValueError(
                f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer. "
                "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
                "'gated-gelu' or 'relu'"
            )

        # for backwards compatibility
        if feed_forward_proj == "gated-gelu":
            self.dense_act_fn = "gelu_new"

        super().__init__(
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            **kwargs,
        )


class LongT5OnnxConfig(OnnxSeq2SeqConfigWithPast):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
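        # Each entry maps an ONNX input name to {axis_index: symbolic_axis_name};
        # axes given symbolic names here are treated as dynamic during ONNX export.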
        common_inputs = {
            "input_ids": {0: "batch", 1: "encoder_sequence"},
            "attention_mask": {0: "batch", 1: "encoder_sequence"},
        }
        if self.use_past:
            common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence"
            common_inputs["decoder_input_ids"] = {0: "batch"}
            common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
        else:
            common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
            common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}

        if self.use_past:
            self.fill_with_past_key_values_(common_inputs, direction="inputs")

        return common_inputs

    @property
    def default_onnx_opset(self) -> int:
        return 13
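

# Illustrative only, not part of the original module: a sketch of how the ONNX
# config above might be instantiated, assuming the `transformers.onnx` export
# utilities accept a config, a task name, and a `use_past` flag:
#
#     config = LongT5Config()
#     onnx_config = LongT5OnnxConfig(config, task="seq2seq-lm", use_past=True)
#     print(onnx_config.inputs)  # with use_past=True, also lists past_key_values.* entries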
|