configuration_bart.py

# coding=utf-8
# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BART model configuration"""

import warnings
from collections import OrderedDict
from typing import Any, Mapping, Optional

from ... import PreTrainedTokenizer
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
from ...onnx.utils import compute_effective_axis_dimension
from ...utils import TensorType, is_torch_available, logging


logger = logging.get_logger(__name__)


class BartConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`BartModel`]. It is used to instantiate a BART
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with
    the defaults will yield a similar configuration to that of the BART
    [facebook/bart-large](https://huggingface.co/facebook/bart-large) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 50265):
            Vocabulary size of the BART model. Defines the number of different tokens that can be represented by the
            `input_ids` passed when calling [`BartModel`] or [`TFBartModel`].
        d_model (`int`, *optional*, defaults to 1024):
            Dimensionality of the layers and the pooler layer.
        encoder_layers (`int`, *optional*, defaults to 12):
            Number of encoder layers.
        decoder_layers (`int`, *optional*, defaults to 12):
            Number of decoder layers.
        encoder_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        decoder_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer decoder.
        decoder_ffn_dim (`int`, *optional*, defaults to 4096):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the decoder.
        encoder_ffn_dim (`int`, *optional*, defaults to 4096):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the encoder.
        activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        dropout (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        activation_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for activations inside the fully connected layer.
        classifier_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the classifier.
        max_position_embeddings (`int`, *optional*, defaults to 1024):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        init_std (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        encoder_layerdrop (`float`, *optional*, defaults to 0.0):
            The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
            for more details.
        decoder_layerdrop (`float`, *optional*, defaults to 0.0):
            The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
            for more details.
        scale_embedding (`bool`, *optional*, defaults to `False`):
            Scale embeddings by dividing by sqrt(d_model).
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        num_labels (`int`, *optional*, defaults to 3):
            The number of labels to use in [`BartForSequenceClassification`].
        forced_eos_token_id (`int`, *optional*, defaults to 2):
            The id of the token to force as the last generated token when `max_length` is reached. Usually set to
            `eos_token_id`.

    Example:

    ```python
    >>> from transformers import BartConfig, BartModel

    >>> # Initializing a BART facebook/bart-large style configuration
    >>> configuration = BartConfig()

    >>> # Initializing a model (with random weights) from the facebook/bart-large style configuration
    >>> model = BartModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""
    model_type = "bart"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}

    def __init__(
        self,
        vocab_size=50265,
        max_position_embeddings=1024,
        encoder_layers=12,
        encoder_ffn_dim=4096,
        encoder_attention_heads=16,
        decoder_layers=12,
        decoder_ffn_dim=4096,
        decoder_attention_heads=16,
        encoder_layerdrop=0.0,
        decoder_layerdrop=0.0,
        activation_function="gelu",
        d_model=1024,
        dropout=0.1,
        attention_dropout=0.0,
        activation_dropout=0.0,
        init_std=0.02,
        classifier_dropout=0.0,
        scale_embedding=False,
        use_cache=True,
        num_labels=3,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        is_encoder_decoder=True,
        decoder_start_token_id=2,
        forced_eos_token_id=2,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.d_model = d_model
        self.encoder_ffn_dim = encoder_ffn_dim
        self.encoder_layers = encoder_layers
        self.encoder_attention_heads = encoder_attention_heads
        self.decoder_ffn_dim = decoder_ffn_dim
        self.decoder_layers = decoder_layers
        self.decoder_attention_heads = decoder_attention_heads
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.activation_function = activation_function
        self.init_std = init_std
        self.encoder_layerdrop = encoder_layerdrop
        self.decoder_layerdrop = decoder_layerdrop
        self.classifier_dropout = classifier_dropout
        self.use_cache = use_cache
        self.num_hidden_layers = encoder_layers
        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True

        super().__init__(
            num_labels=num_labels,
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            decoder_start_token_id=decoder_start_token_id,
            forced_eos_token_id=forced_eos_token_id,
            **kwargs,
        )

        # ensure backward compatibility for BART CNN models
        if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
            self.forced_bos_token_id = self.bos_token_id
            warnings.warn(
                f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
                "The config can simply be saved and uploaded again to be fixed."
            )
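

# A minimal usage sketch (not part of the upstream file): it shows how `attribute_map` lets
# generic code read `hidden_size` and `num_attention_heads` from a `BartConfig`, and how the
# deprecated `force_bos_token_to_be_generated` kwarg is remapped by the backward-compatibility
# branch above. The hyperparameter values below are illustrative assumptions only.
def _demo_bart_config_sketch() -> BartConfig:
    config = BartConfig(d_model=512, encoder_attention_heads=8, encoder_layers=6, decoder_layers=6)
    assert config.hidden_size == 512  # resolved to `d_model` via `attribute_map`
    assert config.num_attention_heads == 8  # resolved to `encoder_attention_heads`

    # Legacy kwarg: triggers the warning above and copies `bos_token_id` into `forced_bos_token_id`.
    legacy_config = BartConfig(force_bos_token_to_be_generated=True)
    assert legacy_config.forced_bos_token_id == legacy_config.bos_token_id
    return config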


class BartOnnxConfig(OnnxSeq2SeqConfigWithPast):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        if self.task in ["default", "seq2seq-lm"]:
            common_inputs = OrderedDict(
                [
                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
                ]
            )

            if self.use_past:
                common_inputs["decoder_input_ids"] = {0: "batch"}
                common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
            else:
                common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
                common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}

            if self.use_past:
                self.fill_with_past_key_values_(common_inputs, direction="inputs")
        elif self.task == "causal-lm":
            # TODO: figure this case out.
            common_inputs = OrderedDict(
                [
                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
                ]
            )
            if self.use_past:
                num_encoder_layers, _ = self.num_layers
                for i in range(num_encoder_layers):
                    common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
                    common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
        else:
            common_inputs = OrderedDict(
                [
                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
                    ("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}),
                    ("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}),
                ]
            )

        return common_inputs

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        if self.task in ["default", "seq2seq-lm"]:
            common_outputs = super().outputs
        else:
            common_outputs = super(OnnxConfigWithPast, self).outputs
            if self.use_past:
                num_encoder_layers, _ = self.num_layers
                for i in range(num_encoder_layers):
                    common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
                    common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
        return common_outputs
    def _generate_dummy_inputs_for_default_and_seq2seq_lm(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
            tokenizer, batch_size, seq_length, is_pair, framework
        )

        # Generate decoder inputs
        decoder_seq_length = seq_length if not self.use_past else 1
        decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
            tokenizer, batch_size, decoder_seq_length, is_pair, framework
        )
        decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
        common_inputs = dict(**encoder_inputs, **decoder_inputs)

        if self.use_past:
            if not is_torch_available():
                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
            else:
                import torch
            batch, encoder_seq_length = common_inputs["input_ids"].shape
            decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
            num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads
            encoder_shape = (
                batch,
                num_encoder_attention_heads,
                encoder_seq_length,
                self._config.hidden_size // num_encoder_attention_heads,
            )
            decoder_past_length = decoder_seq_length + 3
            decoder_shape = (
                batch,
                num_decoder_attention_heads,
                decoder_past_length,
                self._config.hidden_size // num_decoder_attention_heads,
            )

            common_inputs["decoder_attention_mask"] = torch.cat(
                [common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1
            )

            common_inputs["past_key_values"] = []
            # If the number of encoder and decoder layers are present in the model configuration, both are considered
            num_encoder_layers, num_decoder_layers = self.num_layers
            min_num_layers = min(num_encoder_layers, num_decoder_layers)
            max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers
            remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder"

            for _ in range(min_num_layers):
                common_inputs["past_key_values"].append(
                    (
                        torch.zeros(decoder_shape),
                        torch.zeros(decoder_shape),
                        torch.zeros(encoder_shape),
                        torch.zeros(encoder_shape),
                    )
                )
            # TODO: test this.
            shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
            for _ in range(min_num_layers, max_num_layers):
                common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))
        return common_inputs
    def _generate_dummy_inputs_for_causal_lm(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
            tokenizer, batch_size, seq_length, is_pair, framework
        )

        if self.use_past:
            if not is_torch_available():
                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
            else:
                import torch
            batch, seqlen = common_inputs["input_ids"].shape
            # Not using the same length for past_key_values
            past_key_values_length = seqlen + 2
            num_encoder_layers, _ = self.num_layers
            num_encoder_attention_heads, _ = self.num_attention_heads
            past_shape = (
                batch,
                num_encoder_attention_heads,
                past_key_values_length,
                self._config.hidden_size // num_encoder_attention_heads,
            )

            mask_dtype = common_inputs["attention_mask"].dtype
            common_inputs["attention_mask"] = torch.cat(
                [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
            )
            common_inputs["past_key_values"] = [
                (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
            ]
        return common_inputs
    def _generate_dummy_inputs_for_sequence_classification_and_question_answering(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        # Copied from OnnxConfig.generate_dummy_inputs
        # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
        # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
        batch_size = compute_effective_axis_dimension(
            batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
        )

        # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
        token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
        seq_length = compute_effective_axis_dimension(
            seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
        )

        # Generate dummy inputs according to compute batch and sequence
        dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
        common_inputs = dict(tokenizer(dummy_input, return_tensors=framework))
        return common_inputs
    def generate_dummy_inputs(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        if self.task in ["default", "seq2seq-lm"]:
            common_inputs = self._generate_dummy_inputs_for_default_and_seq2seq_lm(
                tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
            )

        elif self.task == "causal-lm":
            common_inputs = self._generate_dummy_inputs_for_causal_lm(
                tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
            )
        else:
            common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
                tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
            )

        return common_inputs
    def _flatten_past_key_values_(self, flattened_output, name, idx, t):
        if self.task in ["default", "seq2seq-lm"]:
            flattened_output = super()._flatten_past_key_values_(flattened_output, name, idx, t)
        else:
            flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_(
                flattened_output, name, idx, t
            )
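

# A minimal usage sketch (not part of the upstream file), assuming PyTorch is installed and the
# `facebook/bart-large` tokenizer is available locally (used here only as an example checkpoint id).
# It shows how `BartOnnxConfig` declares dynamic ONNX axes and builds the matching dummy inputs
# for a plain (no past key/values) seq2seq export.
def _demo_bart_onnx_config_sketch() -> None:
    from transformers import AutoTokenizer

    onnx_config = BartOnnxConfig(BartConfig(), task="default")
    # Dynamic axes declared by the `inputs` property above (batch / sequence dimensions).
    print(dict(onnx_config.inputs))

    tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
    dummy_inputs = onnx_config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
    # Expected keys: input_ids, attention_mask, decoder_input_ids, decoder_attention_mask.
    print(sorted(dummy_inputs.keys()))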