# coding=utf-8
# Copyright 2020, The RAG Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RAG Retriever model implementation."""

import os
import pickle
import time
from typing import Iterable, List, Optional, Tuple

import numpy as np

from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils_base import BatchEncoding
from ...utils import cached_file, is_datasets_available, is_faiss_available, logging, requires_backends, strtobool
from .configuration_rag import RagConfig
from .tokenization_rag import RagTokenizer


if is_datasets_available():
    from datasets import Dataset, load_dataset, load_from_disk

if is_faiss_available():
    import faiss


logger = logging.get_logger(__name__)


LEGACY_INDEX_PATH = "https://storage.googleapis.com/huggingface-nlp/datasets/wiki_dpr/"


class Index:
    """
    A base class for the Indices encapsulated by the [`RagRetriever`].
    """

    def get_doc_dicts(self, doc_ids: np.ndarray) -> List[dict]:
        """
        Returns a list of dictionaries, containing titles and text of the retrieved documents.

        Args:
            doc_ids (`np.ndarray` of shape `(batch_size, n_docs)`):
                A tensor of document indices.
        """
        raise NotImplementedError

    def get_top_docs(self, question_hidden_states: np.ndarray, n_docs=5) -> Tuple[np.ndarray, np.ndarray]:
        """
        For each query in the batch, retrieves `n_docs` documents.

        Args:
            question_hidden_states (`np.ndarray` of shape `(batch_size, vector_size)`):
                An array of query vectors.
            n_docs (`int`):
                The number of docs retrieved per query.

        Returns:
            `np.ndarray` of shape `(batch_size, n_docs)`: A tensor of indices of retrieved documents. `np.ndarray` of
            shape `(batch_size, vector_size)`: A tensor of vector representations of retrieved documents.
        """
        raise NotImplementedError

    def is_initialized(self):
        """
        Returns `True` if index is already initialized.
        """
        raise NotImplementedError

    def init_index(self):
        """
        A function responsible for loading the index into memory. Should be called only once per training run of a RAG
        model. E.g. if the model is trained on multiple GPUs in a distributed setup, only one of the workers will load
        the index.
        """
        raise NotImplementedError


class LegacyIndex(Index):
    """
    An index which can be deserialized from the files built using https://github.com/facebookresearch/DPR. We use
    default faiss index parameters as specified in that repository.

    Args:
        vector_size (`int`):
            The dimension of indexed vectors.
        index_path (`str`):
            A path to a *directory* containing index files compatible with [`~models.rag.retrieval_rag.LegacyIndex`]
    """

    INDEX_FILENAME = "hf_bert_base.hnswSQ8_correct_phi_128.c_index"
    PASSAGE_FILENAME = "psgs_w100.tsv.pkl"

    def __init__(self, vector_size, index_path):
        self.index_id_to_db_id = []
        self.index_path = index_path
        self.passages = self._load_passages()
        self.vector_size = vector_size
        self.index = None
        self._index_initialized = False

    def _resolve_path(self, index_path, filename):
        is_local = os.path.isdir(index_path)
        try:
            # Load from URL or cache if already cached
            resolved_archive_file = cached_file(index_path, filename)
        except EnvironmentError:
            msg = (
                f"Can't load '{filename}'. Make sure that:\n\n"
                f"- '{index_path}' is a correct remote path to a directory containing a file named {filename}\n\n"
                f"- or '{index_path}' is the correct path to a directory containing a file named {filename}.\n\n"
            )
            raise EnvironmentError(msg)
        if is_local:
            logger.info(f"loading file {resolved_archive_file}")
        else:
            logger.info(f"loading file {filename} from cache at {resolved_archive_file}")
        return resolved_archive_file

    def _load_passages(self):
        logger.info(f"Loading passages from {self.index_path}")
        passages_path = self._resolve_path(self.index_path, self.PASSAGE_FILENAME)
        if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")):
            raise ValueError(
                "This part uses `pickle.load` which is insecure and will execute arbitrary code that is potentially "
                "malicious. It's recommended to never unpickle data that could have come from an untrusted source, or "
                "that could have been tampered with. If you already verified the pickle data and decided to use it, "
                "you can set the environment variable `TRUST_REMOTE_CODE` to `True` to allow it."
            )
        with open(passages_path, "rb") as passages_file:
            passages = pickle.load(passages_file)
        return passages

    def _deserialize_index(self):
        logger.info(f"Loading index from {self.index_path}")
        resolved_index_path = self._resolve_path(self.index_path, self.INDEX_FILENAME + ".index.dpr")
        self.index = faiss.read_index(resolved_index_path)
        resolved_meta_path = self._resolve_path(self.index_path, self.INDEX_FILENAME + ".index_meta.dpr")
        if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")):
            raise ValueError(
                "This part uses `pickle.load` which is insecure and will execute arbitrary code that is potentially "
                "malicious. It's recommended to never unpickle data that could have come from an untrusted source, or "
                "that could have been tampered with. If you already verified the pickle data and decided to use it, "
                "you can set the environment variable `TRUST_REMOTE_CODE` to `True` to allow it."
            )
        with open(resolved_meta_path, "rb") as metadata_file:
            self.index_id_to_db_id = pickle.load(metadata_file)
        assert (
            len(self.index_id_to_db_id) == self.index.ntotal
        ), "Deserialized index_id_to_db_id should match faiss index size"

    def is_initialized(self):
        return self._index_initialized

    def init_index(self):
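        # The serialized DPR index stores vectors with one extra auxiliary dimension (hence `self.vector_size + 1`),
        # which is how DPR runs maximum inner product search on an L2-based HNSW index. The HNSW parameters
        # (512 links per node, efSearch=128, efConstruction=200) follow the DPR repository defaults.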
        index = faiss.IndexHNSWFlat(self.vector_size + 1, 512)
        index.hnsw.efSearch = 128
        index.hnsw.efConstruction = 200
        self.index = index
        self._deserialize_index()
        self._index_initialized = True

    def get_doc_dicts(self, doc_ids: np.array):
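        # Passages are looked up by their string database ids; each entry is a (text, title) pair.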
        doc_list = []
        for doc_ids_i in doc_ids:
            ids = [str(int(doc_id)) for doc_id in doc_ids_i]
            docs = [self.passages[doc_id] for doc_id in ids]
            doc_list.append(docs)
        doc_dicts = []
        for docs in doc_list:
            doc_dict = {}
            doc_dict["title"] = [doc[1] for doc in docs]
            doc_dict["text"] = [doc[0] for doc in docs]
            doc_dicts.append(doc_dict)
        return doc_dicts

    def get_top_docs(self, question_hidden_states: np.ndarray, n_docs=5) -> Tuple[np.ndarray, np.ndarray]:
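        # Queries get a zero auxiliary dimension appended so they match the (vector_size + 1)-dimensional index
        # vectors; the extra dimension is stripped again when the document vectors are reconstructed below.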
        aux_dim = np.zeros(len(question_hidden_states), dtype="float32").reshape(-1, 1)
        query_nhsw_vectors = np.hstack((question_hidden_states, aux_dim))
        _, docs_ids = self.index.search(query_nhsw_vectors, n_docs)
        vectors = [[self.index.reconstruct(int(doc_id))[:-1] for doc_id in doc_ids] for doc_ids in docs_ids]
        ids = [[int(self.index_id_to_db_id[doc_id]) for doc_id in doc_ids] for doc_ids in docs_ids]
        return np.array(ids), np.array(vectors)


class HFIndexBase(Index):
    def __init__(self, vector_size, dataset, index_initialized=False):
        self.vector_size = vector_size
        self.dataset = dataset
        self._index_initialized = index_initialized
        self._check_dataset_format(with_index=index_initialized)
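        # Expose the "embeddings" column as float32 numpy arrays while keeping the remaining columns as-is.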
        dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True, dtype="float32")

    def _check_dataset_format(self, with_index: bool):
        if not isinstance(self.dataset, Dataset):
            raise TypeError(f"Dataset should be a datasets.Dataset object, but got {type(self.dataset)}")
        if len({"title", "text", "embeddings"} - set(self.dataset.column_names)) > 0:
            raise ValueError(
                "Dataset should be a dataset with the following columns: "
                "title (str), text (str) and embeddings (arrays of dimension vector_size), "
                f"but got columns {self.dataset.column_names}"
            )
        if with_index and "embeddings" not in self.dataset.list_indexes():
            raise ValueError(
                "Missing faiss index in the dataset. Make sure you called `dataset.add_faiss_index` to compute it "
                "or `dataset.load_faiss_index` to load one from the disk."
            )

    def init_index(self):
        raise NotImplementedError()

    def is_initialized(self):
        return self._index_initialized

    def get_doc_dicts(self, doc_ids: np.ndarray) -> List[dict]:
        return [self.dataset[doc_ids[i].tolist()] for i in range(doc_ids.shape[0])]

    def get_top_docs(self, question_hidden_states: np.ndarray, n_docs=5) -> Tuple[np.ndarray, np.ndarray]:
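        # `search_batch` returns -1 for missing results, so negative ids are filtered out and the corresponding
        # embedding rows are padded with zero vectors to keep a fixed (batch_size, n_docs, vector_size) shape.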
        _, ids = self.dataset.search_batch("embeddings", question_hidden_states, n_docs)
        docs = [self.dataset[[i for i in indices if i >= 0]] for indices in ids]
        vectors = [doc["embeddings"] for doc in docs]
        for i in range(len(vectors)):
            if len(vectors[i]) < n_docs:
                vectors[i] = np.vstack([vectors[i], np.zeros((n_docs - len(vectors[i]), self.vector_size))])
        return np.array(ids), np.array(vectors)  # shapes (batch_size, n_docs) and (batch_size, n_docs, d)


class CanonicalHFIndex(HFIndexBase):
    """
    A wrapper around an instance of [`~datasets.Dataset`]. If `index_path` is set to `None`, we load the pre-computed
    index available with the [`~datasets.arrow_dataset.Dataset`], otherwise, we load the index from the indicated path
    on disk.

    Args:
        vector_size (`int`): the dimension of the passages embeddings used by the index
        dataset_name (`str`, optional, defaults to `wiki_dpr`):
            A dataset identifier of the indexed dataset on HuggingFace AWS bucket (list all available datasets and ids
            with `datasets.list_datasets()`).
        dataset_split (`str`, optional, defaults to `train`):
            Which split of the `dataset` to load.
        index_name (`str`, optional, defaults to `None`):
            The index_name of the index associated with the `dataset`. The index loaded from `index_path` will be saved
            under this name.
        index_path (`str`, optional, defaults to `None`):
            The path to the serialized faiss index on disk.
        use_dummy_dataset (`bool`, optional, defaults to `False`):
            If True, use the dummy configuration of the dataset for tests.
    """

    def __init__(
        self,
        vector_size: int,
        dataset_name: str = "wiki_dpr",
        dataset_split: str = "train",
        index_name: Optional[str] = None,
        index_path: Optional[str] = None,
        use_dummy_dataset=False,
        dataset_revision=None,
    ):
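        # Exactly one of `index_name` and `index_path` must be provided: the index is either loaded by name along
        # with the dataset, or from a serialized faiss file on disk.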
        if int(index_path is None) + int(index_name is None) != 1:
            raise ValueError("Please provide `index_name` or `index_path`.")
        self.dataset_name = dataset_name
        self.dataset_split = dataset_split
        self.index_name = index_name
        self.index_path = index_path
        self.use_dummy_dataset = use_dummy_dataset
        self.dataset_revision = dataset_revision
        logger.info(f"Loading passages from {self.dataset_name}")
        dataset = load_dataset(
            self.dataset_name,
            with_index=False,
            split=self.dataset_split,
            dummy=self.use_dummy_dataset,
            revision=dataset_revision,
        )
        super().__init__(vector_size, dataset, index_initialized=False)

    def init_index(self):
        if self.index_path is not None:
            logger.info(f"Loading index from {self.index_path}")
            self.dataset.load_faiss_index("embeddings", file=self.index_path)
        else:
            logger.info(f"Loading index from {self.dataset_name} with index name {self.index_name}")
            self.dataset = load_dataset(
                self.dataset_name,
                with_embeddings=True,
                with_index=True,
                split=self.dataset_split,
                index_name=self.index_name,
                dummy=self.use_dummy_dataset,
                revision=self.dataset_revision,
            )
            self.dataset.set_format("numpy", columns=["embeddings"], output_all_columns=True)
        self._index_initialized = True


class CustomHFIndex(HFIndexBase):
    """
    A wrapper around an instance of [`~datasets.Dataset`]. The dataset and the index are both loaded from the
    indicated paths on disk.

    Args:
        vector_size (`int`): the dimension of the passages embeddings used by the index
        dataset_path (`str`):
            The path to the serialized dataset on disk. The dataset should have 3 columns: title (str), text (str) and
            embeddings (arrays of dimension vector_size)
        index_path (`str`):
            The path to the serialized faiss index on disk.
    """

    def __init__(self, vector_size: int, dataset, index_path=None):
        super().__init__(vector_size, dataset, index_initialized=index_path is None)
        self.index_path = index_path

    @classmethod
    def load_from_disk(cls, vector_size, dataset_path, index_path):
        logger.info(f"Loading passages from {dataset_path}")
        if dataset_path is None or index_path is None:
            raise ValueError(
                "Please provide `dataset_path` and `index_path` after calling `dataset.save_to_disk(dataset_path)` "
                "and `dataset.get_index('embeddings').save(index_path)`."
            )
        dataset = load_from_disk(dataset_path)
        return cls(vector_size=vector_size, dataset=dataset, index_path=index_path)

    def init_index(self):
        if not self.is_initialized():
            logger.info(f"Loading index from {self.index_path}")
            self.dataset.load_faiss_index("embeddings", file=self.index_path)
            self._index_initialized = True


class RagRetriever:
    """
    Retriever used to get documents from vector queries. It retrieves the documents embeddings as well as the documents
    contents, and it formats them to be used with a RagModel.

    Args:
        config ([`RagConfig`]):
            The configuration of the RAG model this Retriever is used with. Contains parameters indicating which
            `Index` to build. You can load your own custom dataset with `config.index_name="custom"` or use a canonical
            one (default) from the datasets library with `config.index_name="wiki_dpr"` for example.
        question_encoder_tokenizer ([`PreTrainedTokenizer`]):
            The tokenizer that was used to tokenize the question. It is used to decode the question and then use the
            generator_tokenizer.
        generator_tokenizer ([`PreTrainedTokenizer`]):
            The tokenizer used for the generator part of the RagModel.
        index ([`~models.rag.retrieval_rag.Index`], optional, defaults to the one defined by the configuration):
            If specified, use this index instead of the one built using the configuration

    Examples:

    ```python
    >>> # To load the default "wiki_dpr" dataset with 21M passages from wikipedia (index name is 'compressed' or 'exact')
    >>> from transformers import RagRetriever

    >>> retriever = RagRetriever.from_pretrained(
    ...     "facebook/dpr-ctx_encoder-single-nq-base", dataset="wiki_dpr", index_name="compressed"
    ... )

    >>> # To load your own indexed dataset built with the datasets library. More info on how to build the indexed dataset in examples/rag/use_own_knowledge_dataset.py
    >>> from transformers import RagRetriever

    >>> dataset = (
    ...     ...
    ... )  # dataset must be a datasets.Dataset object with columns "title", "text" and "embeddings", and it must have a faiss index
    >>> retriever = RagRetriever.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base", indexed_dataset=dataset)

    >>> # To load your own indexed dataset built with the datasets library that was saved on disk. More info in examples/rag/use_own_knowledge_dataset.py
    >>> from transformers import RagRetriever

    >>> dataset_path = "path/to/my/dataset"  # dataset saved via *dataset.save_to_disk(...)*
    >>> index_path = "path/to/my/index.faiss"  # faiss index saved via *dataset.get_index("embeddings").save(...)*
    >>> retriever = RagRetriever.from_pretrained(
    ...     "facebook/dpr-ctx_encoder-single-nq-base",
    ...     index_name="custom",
    ...     passages_path=dataset_path,
    ...     index_path=index_path,
    ... )

    >>> # To load the legacy index built originally for Rag's paper
    >>> from transformers import RagRetriever

    >>> retriever = RagRetriever.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base", index_name="legacy")
    ```"""

    def __init__(self, config, question_encoder_tokenizer, generator_tokenizer, index=None, init_retrieval=True):
        self._init_retrieval = init_retrieval
        requires_backends(self, ["datasets", "faiss"])
        super().__init__()
        self.index = index or self._build_index(config)
        self.generator_tokenizer = generator_tokenizer
        self.question_encoder_tokenizer = question_encoder_tokenizer
        self.n_docs = config.n_docs
        self.batch_size = config.retrieval_batch_size
        self.config = config
        if self._init_retrieval:
            self.init_retrieval()
        self.ctx_encoder_tokenizer = None
        self.return_tokenized_docs = False

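    # Maps `config.index_name` to an `Index` implementation: "legacy" loads the original DPR index files, "custom"
    # loads a user-provided dataset and faiss index from disk, and any other value falls back to `CanonicalHFIndex`,
    # which loads `config.dataset` (e.g. "wiki_dpr") with `datasets.load_dataset`.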
    @staticmethod
    def _build_index(config):
        if config.index_name == "legacy":
            return LegacyIndex(
                config.retrieval_vector_size,
                config.index_path or LEGACY_INDEX_PATH,
            )
        elif config.index_name == "custom":
            return CustomHFIndex.load_from_disk(
                vector_size=config.retrieval_vector_size,
                dataset_path=config.passages_path,
                index_path=config.index_path,
            )
        else:
            return CanonicalHFIndex(
                vector_size=config.retrieval_vector_size,
                dataset_name=config.dataset,
                dataset_split=config.dataset_split,
                index_name=config.index_name,
                index_path=config.index_path,
                use_dummy_dataset=config.use_dummy_dataset,
                dataset_revision=config.dataset_revision,
            )

    @classmethod
    def from_pretrained(cls, retriever_name_or_path, indexed_dataset=None, **kwargs):
        requires_backends(cls, ["datasets", "faiss"])
        config = kwargs.pop("config", None) or RagConfig.from_pretrained(retriever_name_or_path, **kwargs)
        rag_tokenizer = RagTokenizer.from_pretrained(retriever_name_or_path, config=config)
        question_encoder_tokenizer = rag_tokenizer.question_encoder
        generator_tokenizer = rag_tokenizer.generator
        if indexed_dataset is not None:
            config.index_name = "custom"
            index = CustomHFIndex(config.retrieval_vector_size, indexed_dataset)
        else:
            index = cls._build_index(config)
        return cls(
            config,
            question_encoder_tokenizer=question_encoder_tokenizer,
            generator_tokenizer=generator_tokenizer,
            index=index,
        )

    def save_pretrained(self, save_directory):
        if isinstance(self.index, CustomHFIndex):
            if self.config.index_path is None:
                index_path = os.path.join(save_directory, "hf_dataset_index.faiss")
                self.index.dataset.get_index("embeddings").save(index_path)
                self.config.index_path = index_path
            if self.config.passages_path is None:
                passages_path = os.path.join(save_directory, "hf_dataset")
                # datasets don't support save_to_disk with indexes right now
                faiss_index = self.index.dataset._indexes.pop("embeddings")
                self.index.dataset.save_to_disk(passages_path)
                self.index.dataset._indexes["embeddings"] = faiss_index
                self.config.passages_path = passages_path
        self.config.save_pretrained(save_directory)
        rag_tokenizer = RagTokenizer(
            question_encoder=self.question_encoder_tokenizer,
            generator=self.generator_tokenizer,
        )
        rag_tokenizer.save_pretrained(save_directory)

    def init_retrieval(self):
        """
        Retriever initialization function. It loads the index into memory.
        """
        logger.info("initializing retrieval")
        self.index.init_index()

    def postprocess_docs(self, docs, input_strings, prefix, n_docs, return_tensors=None):
        r"""
        Postprocessing retrieved `docs` and combining them with `input_strings`.

        Args:
            docs (`dict`):
                Retrieved documents.
            input_strings (`str`):
                Input strings decoded by `preprocess_query`.
            prefix (`str`):
                Prefix added at the beginning of each input, typically used with T5-based models.

        Return:
            `tuple(tensors)`: a tuple consisting of two elements: contextualized `input_ids` and a compatible
            `attention_mask`.
        """

        def cat_input_and_doc(doc_title, doc_text, input_string, prefix):
            # TODO(Patrick): if we train more RAG models, I want to put the input first to take advantage of effortless truncation
            # TODO(piktus): better handling of truncation
            if doc_title.startswith('"'):
                doc_title = doc_title[1:]
            if doc_title.endswith('"'):
                doc_title = doc_title[:-1]
            if prefix is None:
                prefix = ""
            out = (prefix + doc_title + self.config.title_sep + doc_text + self.config.doc_sep + input_string).replace(
                "  ", " "
            )
            return out

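        # Build one contextualized string per (query, retrieved document) pair, i.e. len(docs) * n_docs strings in
        # total, each laid out as: prefix + doc_title + title_sep + doc_text + doc_sep + input_string.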
        rag_input_strings = [
            cat_input_and_doc(
                docs[i]["title"][j],
                docs[i]["text"][j],
                input_strings[i],
                prefix,
            )
            for i in range(len(docs))
            for j in range(n_docs)
        ]

        contextualized_inputs = self.generator_tokenizer.batch_encode_plus(
            rag_input_strings,
            max_length=self.config.max_combined_length,
            return_tensors=return_tensors,
            padding="max_length",
            truncation=True,
        )

        return contextualized_inputs["input_ids"], contextualized_inputs["attention_mask"]

    def _chunk_tensor(self, t: Iterable, chunk_size: int) -> List[Iterable]:
        return [t[i : i + chunk_size] for i in range(0, len(t), chunk_size)]

    def _main_retrieve(self, question_hidden_states: np.ndarray, n_docs: int) -> Tuple[np.ndarray, np.ndarray]:
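        # Queries are processed in chunks of `self.batch_size` (`config.retrieval_batch_size`) so that each index
        # search runs on a bounded number of query vectors.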
        question_hidden_states_batched = self._chunk_tensor(question_hidden_states, self.batch_size)
        ids_batched = []
        vectors_batched = []
        for question_hidden_states in question_hidden_states_batched:
            start_time = time.time()
            ids, vectors = self.index.get_top_docs(question_hidden_states, n_docs)
            logger.debug(
                f"index search time: {time.time() - start_time} sec, batch size {question_hidden_states.shape}"
            )
            ids_batched.extend(ids)
            vectors_batched.extend(vectors)
        return (
            np.array(ids_batched),
            np.array(vectors_batched),
        )  # shapes (batch_size, n_docs) and (batch_size, n_docs, d)

    def retrieve(self, question_hidden_states: np.ndarray, n_docs: int) -> Tuple[np.ndarray, np.ndarray, List[dict]]:
        """
        Retrieves documents for specified `question_hidden_states`.

        Args:
            question_hidden_states (`np.ndarray` of shape `(batch_size, vector_size)`):
                A batch of query vectors to retrieve with.
            n_docs (`int`):
                The number of docs retrieved per query.

        Return:
            `Tuple[np.ndarray, np.ndarray, List[dict]]`: A tuple with the following objects:

            - **retrieved_doc_embeds** (`np.ndarray` of shape `(batch_size, n_docs, dim)`) -- The retrieval embeddings
              of the retrieved docs per query.
            - **doc_ids** (`np.ndarray` of shape `(batch_size, n_docs)`) -- The ids of the documents in the index.
            - **doc_dicts** (`List[dict]`) -- The retrieved document dictionaries (titles and texts) per query.
        """
        doc_ids, retrieved_doc_embeds = self._main_retrieve(question_hidden_states, n_docs)
        return retrieved_doc_embeds, doc_ids, self.index.get_doc_dicts(doc_ids)

    def set_ctx_encoder_tokenizer(self, ctx_encoder_tokenizer: PreTrainedTokenizer):
        # used in end2end retriever training
        self.ctx_encoder_tokenizer = ctx_encoder_tokenizer
        self.return_tokenized_docs = True

    def __call__(
        self,
        question_input_ids: List[List[int]],
        question_hidden_states: np.ndarray,
        prefix=None,
        n_docs=None,
        return_tensors=None,
    ) -> BatchEncoding:
        """
        Retrieves documents for specified `question_hidden_states`.

        Args:
            question_input_ids (`List[List[int]]`):
                A batch of input ids.
            question_hidden_states (`np.ndarray` of shape `(batch_size, vector_size)`):
                A batch of query vectors to retrieve with.
            prefix (`str`, *optional*):
                The prefix used by the generator's tokenizer.
            n_docs (`int`, *optional*):
                The number of docs retrieved per query.
            return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to "pt"):
                If set, will return tensors instead of list of python integers. Acceptable values are:

                - `'tf'`: Return TensorFlow `tf.constant` objects.
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return Numpy `np.ndarray` objects.

        Returns: [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:

            - **context_input_ids** -- List of token ids to be fed to a model.

              [What are input IDs?](../glossary#input-ids)

            - **context_attention_mask** -- List of indices specifying which tokens should be attended to by the model
              (when `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names`).

              [What are attention masks?](../glossary#attention-mask)

            - **retrieved_doc_embeds** -- List of embeddings of the retrieved documents
            - **doc_ids** -- List of ids of the retrieved documents
        """
        n_docs = n_docs if n_docs is not None else self.n_docs
        prefix = prefix if prefix is not None else self.config.generator.prefix
        retrieved_doc_embeds, doc_ids, docs = self.retrieve(question_hidden_states, n_docs)

        input_strings = self.question_encoder_tokenizer.batch_decode(question_input_ids, skip_special_tokens=True)
        context_input_ids, context_attention_mask = self.postprocess_docs(
            docs, input_strings, prefix, n_docs, return_tensors=return_tensors
        )

        if self.return_tokenized_docs:
            retrieved_doc_text = []
            retrieved_doc_title = []

            for b_idx in range(len(docs)):
                for doc_idx in range(n_docs):
                    retrieved_doc_text.append(docs[b_idx]["text"][doc_idx])
                    retrieved_doc_title.append(docs[b_idx]["title"][doc_idx])

            tokenized_docs = self.ctx_encoder_tokenizer(
                retrieved_doc_title,
                retrieved_doc_text,
                truncation=True,
                padding="longest",
                return_tensors=return_tensors,
            )

            return BatchEncoding(
                {
                    "context_input_ids": context_input_ids,
                    "context_attention_mask": context_attention_mask,
                    "retrieved_doc_embeds": retrieved_doc_embeds,
                    "doc_ids": doc_ids,
                    "tokenized_doc_ids": tokenized_docs["input_ids"],
                    "tokenized_doc_attention_mask": tokenized_docs["attention_mask"],
                },
                tensor_type=return_tensors,
            )

        else:
            return BatchEncoding(
                {
                    "context_input_ids": context_input_ids,
                    "context_attention_mask": context_attention_mask,
                    "retrieved_doc_embeds": retrieved_doc_embeds,
                    "doc_ids": doc_ids,
                },
                tensor_type=return_tensors,
            )
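

# A minimal usage sketch, assuming a question encoder has already produced `question_hidden_states` as a float32
# numpy array of shape (batch_size, config.retrieval_vector_size) for the tokenized `question_input_ids`:
#
#     retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
#     out = retriever(question_input_ids, question_hidden_states, return_tensors="pt")
#     context_input_ids = out["context_input_ids"]  # shape (batch_size * n_docs, config.max_combined_length)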