configuration_rag.py

# coding=utf-8
# Copyright 2020, The RAG Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RAG model configuration"""

from ...configuration_utils import PretrainedConfig
from ...utils import add_start_docstrings


RAG_CONFIG_DOC = r"""
    [`RagConfig`] stores the configuration of a *RagModel*. Configuration objects inherit from [`PretrainedConfig`] and
    can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information.

    Args:
        title_sep (`str`, *optional*, defaults to `" / "`):
            Separator inserted between the title and the text of the retrieved document when calling [`RagRetriever`].
        doc_sep (`str`, *optional*, defaults to `" // "`):
            Separator inserted between the text of the retrieved document and the original input when calling
            [`RagRetriever`].
        n_docs (`int`, *optional*, defaults to 5):
            Number of documents to retrieve.
        max_combined_length (`int`, *optional*, defaults to 300):
            Max length of contextualized input returned by [`~RagRetriever.__call__`].
        retrieval_vector_size (`int`, *optional*, defaults to 768):
            Dimensionality of the document embeddings indexed by [`RagRetriever`].
        retrieval_batch_size (`int`, *optional*, defaults to 8):
            Retrieval batch size, defined as the number of queries issued concurrently to the faiss index encapsulated
            by [`RagRetriever`].
        dataset (`str`, *optional*, defaults to `"wiki_dpr"`):
            A dataset identifier of the indexed dataset in HuggingFace Datasets (list all available datasets and ids
            using `datasets.list_datasets()`).
        dataset_split (`str`, *optional*, defaults to `"train"`):
            Which split of the `dataset` to load.
        index_name (`str`, *optional*, defaults to `"compressed"`):
            The index name of the index associated with the `dataset`. One can choose between `"legacy"`, `"exact"` and
            `"compressed"`.
        index_path (`str`, *optional*):
            The path to the serialized faiss index on disk.
        passages_path (`str`, *optional*):
            A path to text passages compatible with the faiss index. Required if using
            [`~models.rag.retrieval_rag.LegacyIndex`].
        use_dummy_dataset (`bool`, *optional*, defaults to `False`):
            Whether to load a "dummy" variant of the dataset specified by `dataset`.
        label_smoothing (`float`, *optional*, defaults to 0.0):
            Only relevant if `return_loss` is set to `True`. Controls the `epsilon` parameter value for label smoothing
            in the loss calculation. If set to 0, no label smoothing is performed.
        do_marginalize (`bool`, *optional*, defaults to `False`):
            If `True`, the logits are marginalized over all documents by making use of
            `torch.nn.functional.log_softmax`.
        reduce_loss (`bool`, *optional*, defaults to `False`):
            Whether or not to reduce the NLL loss using the `torch.Tensor.sum` operation.
        do_deduplication (`bool`, *optional*, defaults to `True`):
            Whether or not to deduplicate the generations from different context documents for a given input. Has to be
            set to `False` if used while training with a distributed backend.
        exclude_bos_score (`bool`, *optional*, defaults to `False`):
            Whether or not to disregard the BOS token when computing the loss.
        output_retrieved (`bool`, *optional*, defaults to `False`):
            If set to `True`, `retrieved_doc_embeds`, `retrieved_doc_ids`, `context_input_ids` and
            `context_attention_mask` are returned. See returned tensors for more detail.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        forced_eos_token_id (`int`, *optional*):
            The id of the token to force as the last generated token when `max_length` is reached. Usually set to
            `eos_token_id`.
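
    Example (an illustrative sketch, not tied to any released checkpoint; the sub-configurations below use library
    defaults rather than real checkpoint values):

    ```python
    >>> from transformers import BartConfig, DPRConfig, RagConfig

    >>> # Compose a RAG configuration from a DPR question encoder and a BART generator,
    >>> # passing the sub-configurations as dictionaries.
    >>> config = RagConfig(
    ...     question_encoder=DPRConfig().to_dict(),
    ...     generator=BartConfig().to_dict(),
    ...     n_docs=5,
    ...     index_name="compressed",
    ... )
    ```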
  72. """
@add_start_docstrings(RAG_CONFIG_DOC)
class RagConfig(PretrainedConfig):
    model_type = "rag"
    is_composition = True

    def __init__(
        self,
        vocab_size=None,
        is_encoder_decoder=True,
        prefix=None,
        bos_token_id=None,
        pad_token_id=None,
        eos_token_id=None,
        decoder_start_token_id=None,
        title_sep=" / ",
        doc_sep=" // ",
        n_docs=5,
        max_combined_length=300,
        retrieval_vector_size=768,
        retrieval_batch_size=8,
        dataset="wiki_dpr",
        dataset_split="train",
        index_name="compressed",
        index_path=None,
        passages_path=None,
        use_dummy_dataset=False,
        reduce_loss=False,
        label_smoothing=0.0,
        do_deduplication=True,
        exclude_bos_score=False,
        do_marginalize=False,
        output_retrieved=False,
        use_cache=True,
        forced_eos_token_id=None,
        dataset_revision=None,
        **kwargs,
    ):
        super().__init__(
            bos_token_id=bos_token_id,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            decoder_start_token_id=decoder_start_token_id,
            forced_eos_token_id=forced_eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            prefix=prefix,
            vocab_size=vocab_size,
            **kwargs,
        )
        # A RAG configuration is a composition of two sub-configurations, which must both
        # be provided as keyword arguments (typically as dicts obtained via `to_dict()`).
        if "question_encoder" not in kwargs or "generator" not in kwargs:
            raise ValueError(
                f"A configuration of type {self.model_type} cannot be instantiated because both "
                f"`question_encoder` and `generator` sub-configurations must be passed, but received only {kwargs}"
            )
        question_encoder_config = kwargs.pop("question_encoder")
        question_encoder_model_type = question_encoder_config.pop("model_type")
        decoder_config = kwargs.pop("generator")
        decoder_model_type = decoder_config.pop("model_type")

        from ..auto.configuration_auto import AutoConfig

        self.question_encoder = AutoConfig.for_model(question_encoder_model_type, **question_encoder_config)
        self.generator = AutoConfig.for_model(decoder_model_type, **decoder_config)

        self.reduce_loss = reduce_loss
        self.label_smoothing = label_smoothing
        self.exclude_bos_score = exclude_bos_score
        self.do_marginalize = do_marginalize

        self.title_sep = title_sep
        self.doc_sep = doc_sep
        self.n_docs = n_docs
        self.max_combined_length = max_combined_length

        self.dataset = dataset
        self.dataset_split = dataset_split
        self.index_name = index_name

        self.retrieval_vector_size = retrieval_vector_size
        self.retrieval_batch_size = retrieval_batch_size
        self.passages_path = passages_path
        self.index_path = index_path
        self.use_dummy_dataset = use_dummy_dataset
        self.dataset_revision = dataset_revision

        self.output_retrieved = output_retrieved
        self.do_deduplication = do_deduplication
        self.use_cache = use_cache

        # If no `forced_eos_token_id` was passed explicitly, fall back to the generator's value.
        if self.forced_eos_token_id is None:
            self.forced_eos_token_id = getattr(self.generator, "forced_eos_token_id", None)

    @classmethod
    def from_question_encoder_generator_configs(
        cls, question_encoder_config: PretrainedConfig, generator_config: PretrainedConfig, **kwargs
    ) -> PretrainedConfig:
  158. r"""
  159. Instantiate a [`EncoderDecoderConfig`] (or a derived class) from a pre-trained encoder model configuration and
  160. decoder model configuration.
  161. Returns:
  162. [`EncoderDecoderConfig`]: An instance of a configuration object
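
        Example (an illustrative sketch; the `DPRConfig` and `BartConfig` defaults stand in for real checkpoint
        configurations):

        ```python
        >>> from transformers import BartConfig, DPRConfig, RagConfig

        >>> # Build the composite configuration directly from the two sub-configuration objects.
        >>> config = RagConfig.from_question_encoder_generator_configs(DPRConfig(), BartConfig(), n_docs=5)
        >>> config.generator.model_type
        'bart'
        ```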
  163. """
  164. return cls(question_encoder=question_encoder_config.to_dict(), generator=generator_config.to_dict(), **kwargs)
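

# Usage sketch (illustrative only, not part of the module): the composed configuration
# inherits `forced_eos_token_id` from its generator when the value is not set explicitly,
# per the fallback at the end of `__init__` above.
#
#     from transformers import BartConfig, DPRConfig, RagConfig
#
#     config = RagConfig.from_question_encoder_generator_configs(DPRConfig(), BartConfig())
#     assert config.forced_eos_token_id == config.generator.forced_eos_token_id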