- # coding=utf-8
- # Copyright 2020, The RAG Authors and The HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """RAG Retriever model implementation."""
- import os
- import pickle
- import time
- from typing import Iterable, List, Optional, Tuple
- import numpy as np
- from ...tokenization_utils import PreTrainedTokenizer
- from ...tokenization_utils_base import BatchEncoding
- from ...utils import cached_file, is_datasets_available, is_faiss_available, logging, requires_backends, strtobool
- from .configuration_rag import RagConfig
- from .tokenization_rag import RagTokenizer
- if is_datasets_available():
- from datasets import Dataset, load_dataset, load_from_disk
- if is_faiss_available():
- import faiss
- logger = logging.get_logger(__name__)
- LEGACY_INDEX_PATH = "https://storage.googleapis.com/huggingface-nlp/datasets/wiki_dpr/"
- class Index:
- """
- A base class for the Indices encapsulated by the [`RagRetriever`].
- """
- def get_doc_dicts(self, doc_ids: np.ndarray) -> List[dict]:
- """
- Returns a list of dictionaries, containing titles and text of the retrieved documents.
- Args:
- doc_ids (`np.ndarray` of shape `(batch_size, n_docs)`):
- A tensor of document indices.
- """
- raise NotImplementedError
- def get_top_docs(self, question_hidden_states: np.ndarray, n_docs=5) -> Tuple[np.ndarray, np.ndarray]:
- """
- For each query in the batch, retrieves `n_docs` documents.
- Args:
- question_hidden_states (`np.ndarray` of shape `(batch_size, vector_size)`):
- An array of query vectors.
- n_docs (`int`):
- The number of docs retrieved per query.
- Returns:
- `np.ndarray` of shape `(batch_size, n_docs)`: A tensor of indices of retrieved documents. `np.ndarray` of
- shape `(batch_size, n_docs, vector_size)`: A tensor of vector representations of retrieved documents.
- """
- raise NotImplementedError
- def is_initialized(self):
- """
- Returns `True` if index is already initialized.
- """
- raise NotImplementedError
- def init_index(self):
- """
- A function responsible for loading the index into memory. Should be called only once per training run of a RAG
- model. E.g. if the model is trained on multiple GPUs in a distributed setup, only one of the workers will load
- the index.
- """
- raise NotImplementedError
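- # Illustrative sketch (not part of the original file): a minimal in-memory `Index` implementation,
- # added here only to make the contract above concrete. The class name and its constructor arguments
- # are hypothetical; the returned shapes follow the `get_top_docs` docstring.
- class _ExampleInMemoryIndex(Index):
-     def __init__(self, doc_embeds: np.ndarray, doc_dicts: List[dict]):
-         self.doc_embeds = doc_embeds  # shape (num_docs, vector_size)
-         self.doc_dicts = doc_dicts  # one {"title": ..., "text": ...} dict per indexed document
-         self._initialized = False
-     def is_initialized(self):
-         return self._initialized
-     def init_index(self):
-         # Nothing to load for an in-memory index.
-         self._initialized = True
-     def get_doc_dicts(self, doc_ids: np.ndarray) -> List[dict]:
-         return [
-             {
-                 "title": [self.doc_dicts[int(i)]["title"] for i in ids],
-                 "text": [self.doc_dicts[int(i)]["text"] for i in ids],
-             }
-             for ids in doc_ids
-         ]
-     def get_top_docs(self, question_hidden_states: np.ndarray, n_docs=5) -> Tuple[np.ndarray, np.ndarray]:
-         # Brute-force inner-product search; a real index would delegate this to faiss.
-         scores = question_hidden_states @ self.doc_embeds.T  # (batch_size, num_docs)
-         ids = np.argsort(-scores, axis=1)[:, :n_docs]  # (batch_size, n_docs)
-         vectors = self.doc_embeds[ids]  # (batch_size, n_docs, vector_size)
-         return ids, vectors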
- class LegacyIndex(Index):
- """
- An index which can be deserialized from the files built using https://github.com/facebookresearch/DPR. We use
- default faiss index parameters as specified in that repository.
- Args:
- vector_size (`int`):
- The dimension of indexed vectors.
- index_path (`str`):
- A path to a *directory* containing index files compatible with [`~models.rag.retrieval_rag.LegacyIndex`]
- """
- INDEX_FILENAME = "hf_bert_base.hnswSQ8_correct_phi_128.c_index"
- PASSAGE_FILENAME = "psgs_w100.tsv.pkl"
- def __init__(self, vector_size, index_path):
- self.index_id_to_db_id = []
- self.index_path = index_path
- self.passages = self._load_passages()
- self.vector_size = vector_size
- self.index = None
- self._index_initialized = False
- def _resolve_path(self, index_path, filename):
- is_local = os.path.isdir(index_path)
- try:
- # Load from URL or cache if already cached
- resolved_archive_file = cached_file(index_path, filename)
- except EnvironmentError:
- msg = (
- f"Can't load '{filename}'. Make sure that:\n\n"
- f"- '{index_path}' is a correct remote path to a directory containing a file named {filename}\n\n"
- f"- or '{index_path}' is the correct path to a directory containing a file named {filename}.\n\n"
- )
- raise EnvironmentError(msg)
- if is_local:
- logger.info(f"loading file {resolved_archive_file}")
- else:
- logger.info(f"loading file {filename} from cache at {resolved_archive_file}")
- return resolved_archive_file
- def _load_passages(self):
- logger.info(f"Loading passages from {self.index_path}")
- passages_path = self._resolve_path(self.index_path, self.PASSAGE_FILENAME)
- if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")):
- raise ValueError(
- "This part uses `pickle.load` which is insecure and will execute arbitrary code that is potentially "
- "malicious. It's recommended to never unpickle data that could have come from an untrusted source, or "
- "that could have been tampered with. If you already verified the pickle data and decided to use it, "
- "you can set the environment variable `TRUST_REMOTE_CODE` to `True` to allow it."
- )
- with open(passages_path, "rb") as passages_file:
- passages = pickle.load(passages_file)
- return passages
- def _deserialize_index(self):
- logger.info(f"Loading index from {self.index_path}")
- resolved_index_path = self._resolve_path(self.index_path, self.INDEX_FILENAME + ".index.dpr")
- self.index = faiss.read_index(resolved_index_path)
- resolved_meta_path = self._resolve_path(self.index_path, self.INDEX_FILENAME + ".index_meta.dpr")
- if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")):
- raise ValueError(
- "This part uses `pickle.load` which is insecure and will execute arbitrary code that is potentially "
- "malicious. It's recommended to never unpickle data that could have come from an untrusted source, or "
- "that could have been tampered with. If you already verified the pickle data and decided to use it, "
- "you can set the environment variable `TRUST_REMOTE_CODE` to `True` to allow it."
- )
- with open(resolved_meta_path, "rb") as metadata_file:
- self.index_id_to_db_id = pickle.load(metadata_file)
- assert (
- len(self.index_id_to_db_id) == self.index.ntotal
- ), "Deserialized index_id_to_db_id should match faiss index size"
- def is_initialized(self):
- return self._index_initialized
- def init_index(self):
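- # The legacy DPR index stores vectors with one extra auxiliary dimension (see `get_top_docs`, which
- # pads queries accordingly), hence `vector_size + 1`. The HNSW parameters below follow the defaults
- # used in the DPR repository.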
- index = faiss.IndexHNSWFlat(self.vector_size + 1, 512)
- index.hnsw.efSearch = 128
- index.hnsw.efConstruction = 200
- self.index = index
- self._deserialize_index()
- self._index_initialized = True
- def get_doc_dicts(self, doc_ids: np.ndarray) -> List[dict]:
- doc_list = []
- for doc_ids_i in doc_ids:
- ids = [str(int(doc_id)) for doc_id in doc_ids_i]
- docs = [self.passages[doc_id] for doc_id in ids]
- doc_list.append(docs)
- doc_dicts = []
- for docs in doc_list:
- doc_dict = {}
- doc_dict["title"] = [doc[1] for doc in docs]
- doc_dict["text"] = [doc[0] for doc in docs]
- doc_dicts.append(doc_dict)
- return doc_dicts
- def get_top_docs(self, question_hidden_states: np.ndarray, n_docs=5) -> Tuple[np.ndarray, np.ndarray]:
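- # The legacy DPR index was built with an extra auxiliary dimension that converts maximum inner
- # product search into an L2 search, so each query vector is padded with a zero before searching.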
- aux_dim = np.zeros(len(question_hidden_states), dtype="float32").reshape(-1, 1)
- query_hnsw_vectors = np.hstack((question_hidden_states, aux_dim))
- _, docs_ids = self.index.search(query_hnsw_vectors, n_docs)
- vectors = [[self.index.reconstruct(int(doc_id))[:-1] for doc_id in doc_ids] for doc_ids in docs_ids]
- ids = [[int(self.index_id_to_db_id[doc_id]) for doc_id in doc_ids] for doc_ids in docs_ids]
- return np.array(ids), np.array(vectors)
- class HFIndexBase(Index):
- def __init__(self, vector_size, dataset, index_initialized=False):
- self.vector_size = vector_size
- self.dataset = dataset
- self._index_initialized = index_initialized
- self._check_dataset_format(with_index=index_initialized)
- dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True, dtype="float32")
- def _check_dataset_format(self, with_index: bool):
- if not isinstance(self.dataset, Dataset):
- raise TypeError(f"Dataset should be a datasets.Dataset object, but got {type(self.dataset)}")
- if len({"title", "text", "embeddings"} - set(self.dataset.column_names)) > 0:
- raise ValueError(
- "Dataset should be a dataset with the following columns: "
- "title (str), text (str) and embeddings (arrays of dimension vector_size), "
- f"but got columns {self.dataset.column_names}"
- )
- if with_index and "embeddings" not in self.dataset.list_indexes():
- raise ValueError(
- "Missing faiss index in the dataset. Make sure you called `dataset.add_faiss_index` to compute it "
- "or `dataset.load_faiss_index` to load one from the disk."
- )
- def init_index(self):
- raise NotImplementedError()
- def is_initialized(self):
- return self._index_initialized
- def get_doc_dicts(self, doc_ids: np.ndarray) -> List[dict]:
- return [self.dataset[doc_ids[i].tolist()] for i in range(doc_ids.shape[0])]
- def get_top_docs(self, question_hidden_states: np.ndarray, n_docs=5) -> Tuple[np.ndarray, np.ndarray]:
- _, ids = self.dataset.search_batch("embeddings", question_hidden_states, n_docs)
- docs = [self.dataset[[i for i in indices if i >= 0]] for indices in ids]
- vectors = [doc["embeddings"] for doc in docs]
- for i in range(len(vectors)):
- if len(vectors[i]) < n_docs:
- vectors[i] = np.vstack([vectors[i], np.zeros((n_docs - len(vectors[i]), self.vector_size))])
- return np.array(ids), np.array(vectors) # shapes (batch_size, n_docs) and (batch_size, n_docs, d)
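- # Illustrative sketch (not part of the original file): one way to build a dataset that satisfies
- # `_check_dataset_format`, using the `datasets` library. The titles, texts and random embeddings are
- # placeholders; in practice the embeddings would come from a DPR context encoder with the same
- # `vector_size` as the retriever. Such a dataset can be passed to
- # `RagRetriever.from_pretrained(..., indexed_dataset=dataset)` or wrapped in a `CustomHFIndex`.
- def _example_build_indexed_dataset(vector_size: int = 768) -> "Dataset":
-     dataset = Dataset.from_dict(
-         {
-             "title": ["Aaron", "Zeus"],
-             "text": ["Aaron is a prophet in the Hebrew Bible ...", "Zeus is the sky god in Greek mythology ..."],
-             "embeddings": [np.random.rand(vector_size).tolist(), np.random.rand(vector_size).tolist()],
-         }
-     )
-     # Build the faiss index over the "embeddings" column so that `Dataset.search_batch` works.
-     dataset.add_faiss_index(column="embeddings")
-     return dataset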
- class CanonicalHFIndex(HFIndexBase):
- """
- A wrapper around an instance of [`~datasets.Dataset`]. If `index_path` is set to `None`, we load the pre-computed
- index available with the [`~datasets.arrow_dataset.Dataset`], otherwise, we load the index from the indicated path
- on disk.
- Args:
- vector_size (`int`): the dimension of the passage embeddings used by the index
- dataset_name (`str`, optional, defaults to `wiki_dpr`):
- A dataset identifier of the indexed dataset on HuggingFace AWS bucket (list all available datasets and ids
- with `datasets.list_datasets()`).
- dataset_split (`str`, optional, defaults to `train`):
- Which split of the `dataset` to load.
- index_name (`str`, optional, defaults to `None`):
- The index_name of the index associated with the `dataset`. The index loaded from `index_path` will be saved
- under this name.
- index_path (`str`, optional, defaults to `None`):
- The path to the serialized faiss index on disk.
- use_dummy_dataset (`bool`, optional, defaults to `False`):
- If True, use the dummy configuration of the dataset for tests.
- """
- def __init__(
- self,
- vector_size: int,
- dataset_name: str = "wiki_dpr",
- dataset_split: str = "train",
- index_name: Optional[str] = None,
- index_path: Optional[str] = None,
- use_dummy_dataset=False,
- dataset_revision=None,
- ):
- if int(index_path is None) + int(index_name is None) != 1:
- raise ValueError("Please provide `index_name` or `index_path`.")
- self.dataset_name = dataset_name
- self.dataset_split = dataset_split
- self.index_name = index_name
- self.index_path = index_path
- self.use_dummy_dataset = use_dummy_dataset
- self.dataset_revision = dataset_revision
- logger.info(f"Loading passages from {self.dataset_name}")
- dataset = load_dataset(
- self.dataset_name,
- with_index=False,
- split=self.dataset_split,
- dummy=self.use_dummy_dataset,
- revision=dataset_revision,
- )
- super().__init__(vector_size, dataset, index_initialized=False)
- def init_index(self):
- if self.index_path is not None:
- logger.info(f"Loading index from {self.index_path}")
- self.dataset.load_faiss_index("embeddings", file=self.index_path)
- else:
- logger.info(f"Loading index from {self.dataset_name} with index name {self.index_name}")
- self.dataset = load_dataset(
- self.dataset_name,
- with_embeddings=True,
- with_index=True,
- split=self.dataset_split,
- index_name=self.index_name,
- dummy=self.use_dummy_dataset,
- revision=self.dataset_revision,
- )
- self.dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True)
- self._index_initialized = True
- class CustomHFIndex(HFIndexBase):
- """
- A wrapper around an instance of [`~datasets.Dataset`]. The dataset and the index are both loaded from the
- indicated paths on disk.
- Args:
- vector_size (`int`): the dimension of the passage embeddings used by the index
- dataset_path (`str`):
- The path to the serialized dataset on disk. The dataset should have 3 columns: title (str), text (str) and
- embeddings (arrays of dimension vector_size)
- index_path (`str`):
- The path to the serialized faiss index on disk.
- """
- def __init__(self, vector_size: int, dataset, index_path=None):
- super().__init__(vector_size, dataset, index_initialized=index_path is None)
- self.index_path = index_path
- @classmethod
- def load_from_disk(cls, vector_size, dataset_path, index_path):
- logger.info(f"Loading passages from {dataset_path}")
- if dataset_path is None or index_path is None:
- raise ValueError(
- "Please provide `dataset_path` and `index_path` after calling `dataset.save_to_disk(dataset_path)` "
- "and `dataset.get_index('embeddings').save(index_path)`."
- )
- dataset = load_from_disk(dataset_path)
- return cls(vector_size=vector_size, dataset=dataset, index_path=index_path)
- def init_index(self):
- if not self.is_initialized():
- logger.info(f"Loading index from {self.index_path}")
- self.dataset.load_faiss_index("embeddings", file=self.index_path)
- self._index_initialized = True
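- # Illustrative sketch (not part of the original file): persisting an indexed dataset and reloading it
- # through `CustomHFIndex.load_from_disk`. The helper name and the path arguments are placeholders.
- def _example_save_and_reload_custom_index(dataset, vector_size, dataset_path, index_path):
-     # The faiss index cannot be serialized together with the dataset, so it is saved separately and
-     # dropped before `save_to_disk` (see also `RagRetriever.save_pretrained` below).
-     dataset.get_index("embeddings").save(index_path)
-     dataset.drop_index("embeddings")
-     dataset.save_to_disk(dataset_path)
-     index = CustomHFIndex.load_from_disk(vector_size=vector_size, dataset_path=dataset_path, index_path=index_path)
-     index.init_index()  # loads the serialized faiss index back into the dataset
-     return index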
- class RagRetriever:
- """
- Retriever used to get documents from vector queries. It retrieves the document embeddings as well as the document
- contents, and it formats them to be used with a RagModel.
- Args:
- config ([`RagConfig`]):
- The configuration of the RAG model this Retriever is used with. Contains parameters indicating which
- `Index` to build. You can load your own custom dataset with `config.index_name="custom"` or use a canonical
- one (default) from the datasets library with `config.index_name="wiki_dpr"` for example.
- question_encoder_tokenizer ([`PreTrainedTokenizer`]):
- The tokenizer that was used to tokenize the question. It is used to decode the question, which is then
- re-encoded with the `generator_tokenizer`.
- generator_tokenizer ([`PreTrainedTokenizer`]):
- The tokenizer used for the generator part of the RagModel.
- index ([`~models.rag.retrieval_rag.Index`], optional, defaults to the one defined by the configuration):
- If specified, use this index instead of the one built using the configuration
- Examples:
- ```python
- >>> # To load the default "wiki_dpr" dataset with 21M passages from wikipedia (index name is 'compressed' or 'exact')
- >>> from transformers import RagRetriever
- >>> retriever = RagRetriever.from_pretrained(
- ... "facebook/dpr-ctx_encoder-single-nq-base", dataset="wiki_dpr", index_name="compressed"
- ... )
- >>> # To load your own indexed dataset built with the datasets library. More info on how to build the indexed dataset in examples/rag/use_own_knowledge_dataset.py
- >>> from transformers import RagRetriever
- >>> dataset = (
- ... ...
- ... ) # dataset must be a datasets.Dataset object with columns "title", "text" and "embeddings", and it must have a faiss index
- >>> retriever = RagRetriever.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base", indexed_dataset=dataset)
- >>> # To load your own indexed dataset built with the datasets library that was saved on disk. More info in examples/rag/use_own_knowledge_dataset.py
- >>> from transformers import RagRetriever
- >>> dataset_path = "path/to/my/dataset" # dataset saved via *dataset.save_to_disk(...)*
- >>> index_path = "path/to/my/index.faiss" # faiss index saved via *dataset.get_index("embeddings").save(...)*
- >>> retriever = RagRetriever.from_pretrained(
- ... "facebook/dpr-ctx_encoder-single-nq-base",
- ... index_name="custom",
- ... passages_path=dataset_path,
- ... index_path=index_path,
- ... )
- >>> # To load the legacy index built originally for Rag's paper
- >>> from transformers import RagRetriever
- >>> retriever = RagRetriever.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base", index_name="legacy")
- ```"""
- def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, index=None, init_retrieval=True):
- self._init_retrieval = init_retrieval
- requires_backends(self, ["datasets", "faiss"])
- super().__init__()
- self.index = index or self._build_index(config)
- self.generator_tokenizer = generator_tokenizer
- self.question_encoder_tokenizer = question_encoder_tokenizer
- self.n_docs = config.n_docs
- self.batch_size = config.retrieval_batch_size
- self.config = config
- if self._init_retrieval:
- self.init_retrieval()
- self.ctx_encoder_tokenizer = None
- self.return_tokenized_docs = False
- @staticmethod
- def _build_index(config):
- if config.index_name == "legacy":
- return LegacyIndex(
- config.retrieval_vector_size,
- config.index_path or LEGACY_INDEX_PATH,
- )
- elif config.index_name == "custom":
- return CustomHFIndex.load_from_disk(
- vector_size=config.retrieval_vector_size,
- dataset_path=config.passages_path,
- index_path=config.index_path,
- )
- else:
- return CanonicalHFIndex(
- vector_size=config.retrieval_vector_size,
- dataset_name=config.dataset,
- dataset_split=config.dataset_split,
- index_name=config.index_name,
- index_path=config.index_path,
- use_dummy_dataset=config.use_dummy_dataset,
- dataset_revision=config.dataset_revision,
- )
- @classmethod
- def from_pretrained(cls, retriever_name_or_path, indexed_dataset=None, **kwargs):
- requires_backends(cls, ["datasets", "faiss"])
- config = kwargs.pop("config", None) or RagConfig.from_pretrained(retriever_name_or_path, **kwargs)
- rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config)
- question_encoder_tokenizer = rag_tokenizer.question_encoder
- generator_tokenizer = rag_tokenizer.generator
- if indexed_dataset is not None:
- config.index_name = "custom"
- index = CustomHFIndex(config.retrieval_vector_size, indexed_dataset)
- else:
- index = cls._build_index(config)
- return cls(
- config,
- question_encoder_tokenizer=question_encoder_tokenizer,
- generator_tokenizer=generator_tokenizer,
- index=index,
- )
- def save_pretrained(self, save_directory):
- if isinstance(self.index, CustomHFIndex):
- if self.config.index_path is None:
- index_path = os.path.join(save_directory, "hf_dataset_index.faiss")
- self.index.dataset.get_index("embeddings").save(index_path)
- self.config.index_path = index_path
- if self.config.passages_path is None:
- passages_path = os.path.join(save_directory, "hf_dataset")
- # datasets don't support save_to_disk with indexes right now
- faiss_index = self.index.dataset._indexes.pop("embeddings")
- self.index.dataset.save_to_disk(passages_path)
- self.index.dataset._indexes["embeddings"] = faiss_index
- self.config.passages_path = passages_path
- self.config.save_pretrained(save_directory)
- rag_tokenizer = RagTokenizer(
- question_encoder=self.question_encoder_tokenizer,
- generator=self.generator_tokenizer,
- )
- rag_tokenizer.save_pretrained(save_directory)
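- # Note: a retriever saved with `save_pretrained` above can be restored with
- # `RagRetriever.from_pretrained(save_directory)`; the `index_path` and `passages_path` written to the
- # config are what allow the custom index to be rebuilt on load.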
- def init_retrieval(self):
- """
- Retriever initialization function. It loads the index into memory.
- """
- logger.info("initializing retrieval")
- self.index.init_index()
- def postprocess_docs(self, docs, input_strings, prefix, n_docs, return_tensors=None):
- r"""
- Postprocessing retrieved `docs` and combining them with `input_strings`.
- Args:
- docs (`dict`):
- Retrieved documents.
- input_strings (`str`):
- Input strings decoded by `preprocess_query`.
- prefix (`str`):
- Prefix added at the beginning of each input, typically used with T5-based models.
- Return:
- `tuple(tensors)`: a tuple consisting of two elements: contextualized `input_ids` and a compatible
- `attention_mask`.
- """
- def cat_input_and_doc(doc_title, doc_text, input_string, prefix):
- # TODO(Patrick): if we train more RAG models, I want to put the input first to take advantage of effortless truncation
- # TODO(piktus): better handling of truncation
- if doc_title.startswith('"'):
- doc_title = doc_title[1:]
- if doc_title.endswith('"'):
- doc_title = doc_title[:-1]
- if prefix is None:
- prefix = ""
- out = (prefix + doc_title + self.config.title_sep + doc_text + self.config.doc_sep + input_string).replace(
- " ", " "
- )
- return out
- rag_input_strings = [
- cat_input_and_doc(
- docs[i]["title"][j],
- docs[i]["text"][j],
- input_strings[i],
- prefix,
- )
- for i in range(len(docs))
- for j in range(n_docs)
- ]
- contextualized_inputs = self.generator_tokenizer.batch_encode_plus(
- rag_input_strings,
- max_length=self.config.max_combined_length,
- return_tensors=return_tensors,
- padding="max_length",
- truncation=True,
- )
- return contextualized_inputs["input_ids"], contextualized_inputs["attention_mask"]
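- # Example of the format produced by `postprocess_docs` above: with the default RagConfig separators
- # (`title_sep=" / "`, `doc_sep=" // "`) and no prefix, each contextualized input is the string
- # "<doc title> / <doc text> // <question>", tokenized and padded/truncated to `max_combined_length`.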
- def _chunk_tensor(self, t: Iterable, chunk_size: int) -> List[Iterable]:
- return [t[i : i + chunk_size] for i in range(0, len(t), chunk_size)]
- def _main_retrieve(self, question_hidden_states: np.ndarray, n_docs: int) -> Tuple[np.ndarray, np.ndarray]:
- question_hidden_states_batched = self._chunk_tensor(question_hidden_states, self.batch_size)
- ids_batched = []
- vectors_batched = []
- for question_hidden_states in question_hidden_states_batched:
- start_time = time.time()
- ids, vectors = self.index.get_top_docs(question_hidden_states, n_docs)
- logger.debug(
- f"index search time: {time.time() - start_time} sec, batch size {question_hidden_states.shape}"
- )
- ids_batched.extend(ids)
- vectors_batched.extend(vectors)
- return (
- np.array(ids_batched),
- np.array(vectors_batched),
- ) # shapes (batch_size, n_docs) and (batch_size, n_docs, d)
- def retrieve(self, question_hidden_states: np.ndarray, n_docs: int) -> Tuple[np.ndarray, np.ndarray, List[dict]]:
- """
- Retrieves documents for specified `question_hidden_states`.
- Args:
- question_hidden_states (`np.ndarray` of shape `(batch_size, vector_size)`):
- A batch of query vectors to retrieve with.
- n_docs (`int`):
- The number of docs retrieved per query.
- Return:
- `Tuple[np.ndarray, np.ndarray, List[dict]]`: A tuple with the following objects:
- - **retrieved_doc_embeds** (`np.ndarray` of shape `(batch_size, n_docs, dim)`) -- The retrieval embeddings
- of the retrieved docs per query.
- - **doc_ids** (`np.ndarray` of shape `(batch_size, n_docs)`) -- The ids of the documents in the index
- - **doc_dicts** (`List[dict]`) -- The retrieved document examples (e.g. titles and texts) per query.
- """
- doc_ids, retrieved_doc_embeds = self._main_retrieve(question_hidden_states, n_docs)
- return retrieved_doc_embeds, doc_ids, self.index.get_doc_dicts(doc_ids)
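- # Example of the shapes returned by `retrieve` above: for `question_hidden_states` of shape (2, 768)
- # and `n_docs=5`, `retrieved_doc_embeds` has shape (2, 5, 768), `doc_ids` has shape (2, 5), and the
- # doc dicts are a list of 2 dictionaries whose "title" and "text" entries each hold 5 strings.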
- def set_ctx_encoder_tokenizer(self, ctx_encoder_tokenizer: PreTrainedTokenizer):
- # used in end2end retriever training
- self.ctx_encoder_tokenizer = ctx_encoder_tokenizer
- self.return_tokenized_docs = True
- def __call__(
- self,
- question_input_ids: List[List[int]],
- question_hidden_states: np.ndarray,
- prefix=None,
- n_docs=None,
- return_tensors=None,
- ) -> BatchEncoding:
- """
- Retrieves documents for specified `question_hidden_states`.
- Args:
- question_input_ids (`List[List[int]]`): Batch of input ids.
- question_hidden_states (`np.ndarray` of shape `(batch_size, vector_size)`):
- A batch of query vectors to retrieve with.
- prefix (`str`, *optional*):
- The prefix used by the generator's tokenizer.
- n_docs (`int`, *optional*):
- The number of docs retrieved per query.
- return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to "pt"):
- If set, will return tensors instead of list of python integers. Acceptable values are:
- - `'tf'`: Return TensorFlow `tf.constant` objects.
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return Numpy `np.ndarray` objects.
- Returns: [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
- - **context_input_ids** -- List of token ids to be fed to a model.
- [What are input IDs?](../glossary#input-ids)
- - **context_attention_mask** -- List of indices specifying which tokens should be attended to by the model
- (when `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names`).
- [What are attention masks?](../glossary#attention-mask)
- - **retrieved_doc_embeds** -- List of embeddings of the retrieved documents
- - **doc_ids** -- List of ids of the retrieved documents
- """
- n_docs = n_docs if n_docs is not None else self.n_docs
- prefix = prefix if prefix is not None else self.config.generator.prefix
- retrieved_doc_embeds, doc_ids, docs = self.retrieve(question_hidden_states, n_docs)
- input_strings = self.question_encoder_tokenizer.batch_decode(question_input_ids, skip_special_tokens=True)
- context_input_ids, context_attention_mask = self.postprocess_docs(
- docs, input_strings, prefix, n_docs, return_tensors=return_tensors
- )
- if self.return_tokenized_docs:
- retrieved_doc_text = []
- retrieved_doc_title = []
- for b_idx in range(len(docs)):
- for doc_idx in range(n_docs):
- retrieved_doc_text.append(docs[b_idx]["text"][doc_idx])
- retrieved_doc_title.append(docs[b_idx]["title"][doc_idx])
- tokenized_docs = self.ctx_encoder_tokenizer(
- retrieved_doc_title,
- retrieved_doc_text,
- truncation=True,
- padding="longest",
- return_tensors=return_tensors,
- )
- return BatchEncoding(
- {
- "context_input_ids": context_input_ids,
- "context_attention_mask": context_attention_mask,
- "retrieved_doc_embeds": retrieved_doc_embeds,
- "doc_ids": doc_ids,
- "tokenized_doc_ids": tokenized_docs["input_ids"],
- "tokenized_doc_attention_mask": tokenized_docs["attention_mask"],
- },
- tensor_type=return_tensors,
- )
- else:
- return BatchEncoding(
- {
- "context_input_ids": context_input_ids,
- "context_attention_mask": context_attention_mask,
- "retrieved_doc_embeds": retrieved_doc_embeds,
- "doc_ids": doc_ids,
- },
- tensor_type=return_tensors,
- )
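- # Illustrative sketch (not part of the original file): end-to-end retrieval with a DPR question
- # encoder, mirroring the library documentation. The model names and the dummy "wiki_dpr" index are
- # assumptions used only for illustration and require downloading the corresponding artifacts.
- def _example_end_to_end_retrieval():
-     import torch
-     from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
-     tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
-     encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
-     retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
-     inputs = tokenizer("who holds the record in 100m freestyle", return_tensors="pt")
-     with torch.no_grad():
-         question_hidden_states = encoder(**inputs).pooler_output  # shape (1, 768)
-     out = retriever(inputs["input_ids"].numpy(), question_hidden_states.numpy(), return_tensors="pt")
-     # `out` is a BatchEncoding with "context_input_ids", "context_attention_mask",
-     # "retrieved_doc_embeds" and "doc_ids".
-     return out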