# coding=utf-8
# Copyright 2020 Google Research and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
TAPAS configuration. Based on the BERT configuration with added parameters.

Hyperparameters are taken from run_task_main.py and hparam_utils.py of the original implementation. URLs:

- https://github.com/google-research/tapas/blob/master/tapas/run_task_main.py
- https://github.com/google-research/tapas/blob/master/tapas/utils/hparam_utils.py
"""

from ...configuration_utils import PretrainedConfig


class TapasConfig(PretrainedConfig):
  23. r"""
  24. This is the configuration class to store the configuration of a [`TapasModel`]. It is used to instantiate a TAPAS
  25. model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
  26. defaults will yield a similar configuration to that of the TAPAS
  27. [google/tapas-base-finetuned-sqa](https://huggingface.co/google/tapas-base-finetuned-sqa) architecture.
  28. Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
  29. documentation from [`PretrainedConfig`] for more information.
  30. Hyperparameters additional to BERT are taken from run_task_main.py and hparam_utils.py of the original
  31. implementation. Original implementation available at https://github.com/google-research/tapas/tree/master.
  32. Args:
        vocab_size (`int`, *optional*, defaults to 30522):
            Vocabulary size of the TAPAS model. Defines the number of different tokens that can be represented by the
            `input_ids` passed when calling [`TapasModel`].
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"swish"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (`int`, *optional*, defaults to 1024):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        type_vocab_sizes (`List[int]`, *optional*, defaults to `[3, 256, 256, 2, 256, 256, 10]`):
            The vocabulary sizes of the `token_type_ids` passed when calling [`TapasModel`].
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        positive_label_weight (`float`, *optional*, defaults to 10.0):
            Weight for positive labels.
        num_aggregation_labels (`int`, *optional*, defaults to 0):
            The number of aggregation operators to predict.
        aggregation_loss_weight (`float`, *optional*, defaults to 1.0):
            Importance weight for the aggregation loss.
        use_answer_as_supervision (`bool`, *optional*):
            Whether to use the answer as the only supervision for aggregation examples.
        answer_loss_importance (`float`, *optional*, defaults to 1.0):
            Importance weight for the regression loss.
        use_normalized_answer_loss (`bool`, *optional*, defaults to `False`):
            Whether to normalize the answer loss by the maximum of the predicted and expected value.
        huber_loss_delta (`float`, *optional*):
            Delta parameter used to calculate the regression loss.
        temperature (`float`, *optional*, defaults to 1.0):
            Value used to control (OR change) the skewness of cell logits probabilities.
        aggregation_temperature (`float`, *optional*, defaults to 1.0):
            Scales aggregation logits to control the skewness of probabilities.
        use_gumbel_for_cells (`bool`, *optional*, defaults to `False`):
            Whether to apply Gumbel-Softmax to cell selection.
        use_gumbel_for_aggregation (`bool`, *optional*, defaults to `False`):
            Whether to apply Gumbel-Softmax to aggregation selection.
        average_approximation_function (`str`, *optional*, defaults to `"ratio"`):
            Method to calculate the expected average of cells in the weak supervision case. One of `"ratio"`,
            `"first_order"` or `"second_order"`.
        cell_selection_preference (`float`, *optional*):
            Preference for cell selection in ambiguous cases. Only applicable in case of weak supervision for
            aggregation (WTQ, WikiSQL). If the total mass of the aggregation probabilities (excluding the "NONE"
            operator) is higher than this hyperparameter, then aggregation is predicted for an example.
        answer_loss_cutoff (`float`, *optional*):
            Ignore examples with answer loss larger than cutoff.
        max_num_rows (`int`, *optional*, defaults to 64):
            Maximum number of rows.
        max_num_columns (`int`, *optional*, defaults to 32):
            Maximum number of columns.
        average_logits_per_cell (`bool`, *optional*, defaults to `False`):
            Whether to average logits per cell.
        select_one_column (`bool`, *optional*, defaults to `True`):
            Whether to constrain the model to only select cells from a single column.
        allow_empty_column_selection (`bool`, *optional*, defaults to `False`):
            Whether to allow not to select any column.
        init_cell_selection_weights_to_zero (`bool`, *optional*, defaults to `False`):
            Whether to initialize cell selection weights to 0 so that the initial probabilities are 50%.
        reset_position_index_per_cell (`bool`, *optional*, defaults to `True`):
            Whether to restart position indexes at every cell (i.e. use relative position embeddings).
        disable_per_token_loss (`bool`, *optional*, defaults to `False`):
            Whether to disable any (strong or weak) supervision on cells.
        aggregation_labels (`Dict[int, str]`, *optional*):
            The aggregation labels used to aggregate the results. For example, the WTQ models have the following
            aggregation labels: `{0: "NONE", 1: "SUM", 2: "AVERAGE", 3: "COUNT"}`
        no_aggregation_label_index (`int`, *optional*):
            If the aggregation labels are defined and one of these labels represents "No aggregation", this should be
            set to its index. For example, the WTQ models have the "NONE" aggregation label at index 0, so that value
            should be set to 0 for these models.
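
    Example (weak supervision for aggregation). An illustrative sketch of a WTQ-style setup using the aggregation
    labels documented above; the two threshold values are placeholder assumptions for illustration, not the released
    fine-tuned hyperparameters:

    ```python
    >>> from transformers import TapasConfig

    >>> wtq_style_config = TapasConfig(
    ...     num_aggregation_labels=4,
    ...     aggregation_labels={0: "NONE", 1: "SUM", 2: "AVERAGE", 3: "COUNT"},
    ...     no_aggregation_label_index=0,
    ...     use_answer_as_supervision=True,
    ...     cell_selection_preference=0.2,  # placeholder value, not the released hyperparameter
    ...     answer_loss_cutoff=0.66,  # placeholder value, not the released hyperparameter
    ... )
    ```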

    Example:

    ```python
    >>> from transformers import TapasModel, TapasConfig

    >>> # Initializing a default (SQA) Tapas configuration
    >>> configuration = TapasConfig()
    >>> # Initializing a model from the configuration
    >>> model = TapasModel(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""
    model_type = "tapas"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=1024,
        type_vocab_sizes=[3, 256, 256, 2, 256, 256, 10],
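        # The seven vocabulary sizes above correspond to the seven token type id dimensions produced by the
        # TAPAS tokenizer: segment_ids, column_ids, row_ids, prev_labels, column_ranks, inv_column_ranks
        # and numeric_relations.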
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        positive_label_weight=10.0,
        num_aggregation_labels=0,
        aggregation_loss_weight=1.0,
        use_answer_as_supervision=None,
        answer_loss_importance=1.0,
        use_normalized_answer_loss=False,
        huber_loss_delta=None,
        temperature=1.0,
        aggregation_temperature=1.0,
        use_gumbel_for_cells=False,
        use_gumbel_for_aggregation=False,
        average_approximation_function="ratio",
        cell_selection_preference=None,
        answer_loss_cutoff=None,
        max_num_rows=64,
        max_num_columns=32,
        average_logits_per_cell=False,
        select_one_column=True,
        allow_empty_column_selection=False,
        init_cell_selection_weights_to_zero=False,
        reset_position_index_per_cell=True,
        disable_per_token_loss=False,
        aggregation_labels=None,
        no_aggregation_label_index=None,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        # BERT hyperparameters (with updated max_position_embeddings and type_vocab_sizes)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_sizes = type_vocab_sizes
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps

        # Fine-tuning task hyperparameters
        self.positive_label_weight = positive_label_weight
        self.num_aggregation_labels = num_aggregation_labels
        self.aggregation_loss_weight = aggregation_loss_weight
        self.use_answer_as_supervision = use_answer_as_supervision
        self.answer_loss_importance = answer_loss_importance
        self.use_normalized_answer_loss = use_normalized_answer_loss
        self.huber_loss_delta = huber_loss_delta
        self.temperature = temperature
        self.aggregation_temperature = aggregation_temperature
        self.use_gumbel_for_cells = use_gumbel_for_cells
        self.use_gumbel_for_aggregation = use_gumbel_for_aggregation
        self.average_approximation_function = average_approximation_function
        self.cell_selection_preference = cell_selection_preference
        self.answer_loss_cutoff = answer_loss_cutoff
        self.max_num_rows = max_num_rows
        self.max_num_columns = max_num_columns
        self.average_logits_per_cell = average_logits_per_cell
        self.select_one_column = select_one_column
        self.allow_empty_column_selection = allow_empty_column_selection
        self.init_cell_selection_weights_to_zero = init_cell_selection_weights_to_zero
        self.reset_position_index_per_cell = reset_position_index_per_cell
        self.disable_per_token_loss = disable_per_token_loss

        # Aggregation hyperparameters
        self.aggregation_labels = aggregation_labels
        self.no_aggregation_label_index = no_aggregation_label_index

        if isinstance(self.aggregation_labels, dict):
            self.aggregation_labels = {int(k): v for k, v in aggregation_labels.items()}
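

# The `int(k)` coercion above matters because a JSON round-trip (e.g. `to_json_string()` followed by
# reloading) turns integer dict keys into strings, as JSON objects only permit string keys. A minimal,
# runnable sketch of that behaviour (illustrative; not part of the original module):
if __name__ == "__main__":
    import json

    labels = {0: "NONE", 1: "SUM", 2: "AVERAGE", 3: "COUNT"}

    # After a JSON round-trip, the integer keys come back as strings...
    round_tripped = json.loads(json.dumps(labels))
    assert list(round_tripped) == ["0", "1", "2", "3"]

    # ...so coercing them back with int() restores the original mapping.
    assert {int(k): v for k, v in round_tripped.items()} == labels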