# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Doc utilities: Utilities related to documentation
"""

import functools
import re
import types


def add_start_docstrings(*docstr):
    def docstring_decorator(fn):
        fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
        return fn

    return docstring_decorator


def add_start_docstrings_to_model_forward(*docstr):
    def docstring_decorator(fn):
        docstring = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
        class_name = f"[`{fn.__qualname__.split('.')[0]}`]"
        intro = f" The {class_name} forward method overrides the `__call__` special method."
        note = r"""

    <Tip>

    Although the recipe for the forward pass needs to be defined within this function, one should call the [`Module`]
    instance afterwards instead of this one, since the former takes care of running the pre- and post-processing steps
    while the latter silently ignores them.

    </Tip>
"""

        fn.__doc__ = intro + note + docstring
        return fn

    return docstring_decorator


def add_end_docstrings(*docstr):
    def docstring_decorator(fn):
        fn.__doc__ = (fn.__doc__ if fn.__doc__ is not None else "") + "".join(docstr)
        return fn

    return docstring_decorator
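
# Illustrative sketch (not part of the original module): these decorators only concatenate extra text
# around the wrapped callable's `__doc__`, e.g.
#
#     @add_start_docstrings("Shared introduction. ")
#     @add_end_docstrings("Shared closing note.")
#     def f():
#         """Function-specific details. """
#
#     # f.__doc__ == "Shared introduction. Function-specific details. Shared closing note."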


PT_RETURN_INTRODUCTION = r"""
    Returns:
        [`{full_output_type}`] or `tuple(torch.FloatTensor)`: A [`{full_output_type}`] or a tuple of
        `torch.FloatTensor` (if `return_dict=False` is passed or when `config.return_dict=False`) comprising various
        elements depending on the configuration ([`{config_class}`]) and inputs.

"""

TF_RETURN_INTRODUCTION = r"""
    Returns:
        [`{full_output_type}`] or `tuple(tf.Tensor)`: A [`{full_output_type}`] or a tuple of `tf.Tensor` (if
        `return_dict=False` is passed or when `config.return_dict=False`) comprising various elements depending on the
        configuration ([`{config_class}`]) and inputs.

"""


def _get_indent(t):
    """Returns the indentation in the first line of t"""
    search = re.search(r"^(\s*)\S", t)
    return "" if search is None else search.groups()[0]


def _convert_output_args_doc(output_args_doc):
    """Convert output_args_doc to display properly."""
    # Split output_args_doc into blocks of argument name + description
    indent = _get_indent(output_args_doc)
    blocks = []
    current_block = ""
    for line in output_args_doc.split("\n"):
        # If the indent is the same as the beginning, the line is the name of a new arg.
        if _get_indent(line) == indent:
            if len(current_block) > 0:
                blocks.append(current_block[:-1])
            current_block = f"{line}\n"
        else:
            # Otherwise it's part of the description of the current arg.
            # We need to remove 2 spaces from the indentation.
            current_block += f"{line[2:]}\n"
    blocks.append(current_block[:-1])

    # Format each block for proper rendering
    for i in range(len(blocks)):
        blocks[i] = re.sub(r"^(\s+)(\S+)(\s+)", r"\1- **\2**\3", blocks[i])
        blocks[i] = re.sub(r":\s*\n\s*(\S)", r" -- \1", blocks[i])

    return "\n".join(blocks)
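
# Illustrative sketch (not part of the original module): `_convert_output_args_doc` turns an
# `Args`-style block such as
#
#     loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
#         Classification loss.
#
# into the Markdown-friendly single-line form
#
#     - **loss** (`torch.FloatTensor` of shape `(1,)`, *optional*) -- Classification loss.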


def _prepare_output_docstrings(output_type, config_class, min_indent=None):
    """
    Prepares the return part of the docstring using `output_type`.
    """
    output_docstring = output_type.__doc__

    # Remove the head of the docstring to keep the list of args only
    lines = output_docstring.split("\n")
    i = 0
    while i < len(lines) and re.search(r"^\s*(Args|Parameters):\s*$", lines[i]) is None:
        i += 1
    if i < len(lines):
        params_docstring = "\n".join(lines[(i + 1) :])
        params_docstring = _convert_output_args_doc(params_docstring)
    else:
        raise ValueError(
            f"No `Args` or `Parameters` section is found in the docstring of `{output_type.__name__}`. Make sure it "
            "has a docstring and contains either `Args` or `Parameters`."
        )

    # Add the return introduction
    full_output_type = f"{output_type.__module__}.{output_type.__name__}"
    intro = TF_RETURN_INTRODUCTION if output_type.__name__.startswith("TF") else PT_RETURN_INTRODUCTION
    intro = intro.format(full_output_type=full_output_type, config_class=config_class)
    result = intro + params_docstring

    # Apply minimum indent if necessary
    if min_indent is not None:
        lines = result.split("\n")
        # Find the indent of the first nonempty line
        i = 0
        while len(lines[i]) == 0:
            i += 1
        indent = len(_get_indent(lines[i]))
        # If it is too small, add indentation to all nonempty lines
        if indent < min_indent:
            to_add = " " * (min_indent - indent)
            lines = [(f"{to_add}{line}" if len(line) > 0 else line) for line in lines]
            result = "\n".join(lines)
    return result
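
# Illustrative sketch (not part of the original module): given an output class whose docstring has an
# `Args:` section, `_prepare_output_docstrings` returns the matching "Returns:" introduction followed by
# that section reformatted by `_convert_output_args_doc`, e.g.
#
#     _prepare_output_docstrings(SequenceClassifierOutput, "BertConfig", min_indent=4)
#
# where `SequenceClassifierOutput` and `"BertConfig"` are hypothetical arguments used only for the example.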


FAKE_MODEL_DISCLAIMER = """
    <Tip warning={true}>

    This example uses a random model as the real ones are all very big. To get proper results, you should use
    {real_checkpoint} instead of {fake_checkpoint}. If you run out of memory when loading that checkpoint, you can try
    adding `device_map="auto"` in the `from_pretrained` call.

    </Tip>
"""

PT_TOKEN_CLASSIFICATION_SAMPLE = r"""
    Example:

    ```python
    >>> from transformers import AutoTokenizer, {model_class}
    >>> import torch

    >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")

    >>> inputs = tokenizer(
    ...     "HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="pt"
    ... )

    >>> with torch.no_grad():
    ...     logits = model(**inputs).logits

    >>> predicted_token_class_ids = logits.argmax(-1)

    >>> # Note that tokens are classified rather than input words, which means that
    >>> # there might be more predicted token classes than words.
    >>> # Multiple token classes might account for the same word
    >>> predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids[0]]
    >>> predicted_tokens_classes
    {expected_output}

    >>> labels = predicted_token_class_ids
    >>> loss = model(**inputs, labels=labels).loss
    >>> round(loss.item(), 2)
    {expected_loss}
    ```
"""

PT_QUESTION_ANSWERING_SAMPLE = r"""
    Example:

    ```python
    >>> from transformers import AutoTokenizer, {model_class}
    >>> import torch

    >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")

    >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
    >>> inputs = tokenizer(question, text, return_tensors="pt")
    >>> with torch.no_grad():
    ...     outputs = model(**inputs)

    >>> answer_start_index = outputs.start_logits.argmax()
    >>> answer_end_index = outputs.end_logits.argmax()

    >>> predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
    >>> tokenizer.decode(predict_answer_tokens, skip_special_tokens=True)
    {expected_output}

    >>> # target is "nice puppet"
    >>> target_start_index = torch.tensor([{qa_target_start_index}])
    >>> target_end_index = torch.tensor([{qa_target_end_index}])

    >>> outputs = model(**inputs, start_positions=target_start_index, end_positions=target_end_index)
    >>> loss = outputs.loss
    >>> round(loss.item(), 2)
    {expected_loss}
    ```
"""

PT_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
    Example of single-label classification:

    ```python
    >>> import torch
    >>> from transformers import AutoTokenizer, {model_class}

    >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")

    >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

    >>> with torch.no_grad():
    ...     logits = model(**inputs).logits

    >>> predicted_class_id = logits.argmax().item()
    >>> model.config.id2label[predicted_class_id]
    {expected_output}

    >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
    >>> num_labels = len(model.config.id2label)
    >>> model = {model_class}.from_pretrained("{checkpoint}", num_labels=num_labels)

    >>> labels = torch.tensor([1])
    >>> loss = model(**inputs, labels=labels).loss
    >>> round(loss.item(), 2)
    {expected_loss}
    ```

    Example of multi-label classification:

    ```python
    >>> import torch
    >>> from transformers import AutoTokenizer, {model_class}

    >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}", problem_type="multi_label_classification")

    >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

    >>> with torch.no_grad():
    ...     logits = model(**inputs).logits

    >>> predicted_class_ids = torch.arange(0, logits.shape[-1])[torch.sigmoid(logits).squeeze(dim=0) > 0.5]

    >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
    >>> num_labels = len(model.config.id2label)
    >>> model = {model_class}.from_pretrained(
    ...     "{checkpoint}", num_labels=num_labels, problem_type="multi_label_classification"
    ... )

    >>> labels = torch.sum(
    ...     torch.nn.functional.one_hot(predicted_class_ids[None, :].clone(), num_classes=num_labels), dim=1
    ... ).to(torch.float)
    >>> loss = model(**inputs, labels=labels).loss
    ```
"""

PT_MASKED_LM_SAMPLE = r"""
    Example:

    ```python
    >>> from transformers import AutoTokenizer, {model_class}
    >>> import torch

    >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")

    >>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="pt")

    >>> with torch.no_grad():
    ...     logits = model(**inputs).logits

    >>> # retrieve index of {mask}
    >>> mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]

    >>> predicted_token_id = logits[0, mask_token_index].argmax(axis=-1)
    >>> tokenizer.decode(predicted_token_id)
    {expected_output}

    >>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]
    >>> # mask labels of non-{mask} tokens
    >>> labels = torch.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)

    >>> outputs = model(**inputs, labels=labels)
    >>> round(outputs.loss.item(), 2)
    {expected_loss}
    ```
"""

PT_BASE_MODEL_SAMPLE = r"""
    Example:

    ```python
    >>> from transformers import AutoTokenizer, {model_class}
    >>> import torch

    >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")

    >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
    >>> outputs = model(**inputs)

    >>> last_hidden_states = outputs.last_hidden_state
    ```
"""

PT_MULTIPLE_CHOICE_SAMPLE = r"""
    Example:

    ```python
    >>> from transformers import AutoTokenizer, {model_class}
    >>> import torch

    >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")

    >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
    >>> choice0 = "It is eaten with a fork and a knife."
    >>> choice1 = "It is eaten while held in the hand."
    >>> labels = torch.tensor(0).unsqueeze(0)  # choice0 is correct (according to Wikipedia ;)), batch size 1

    >>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
    >>> outputs = model(**{{k: v.unsqueeze(0) for k, v in encoding.items()}}, labels=labels)  # batch size is 1

    >>> # the linear classifier still needs to be trained
    >>> loss = outputs.loss
    >>> logits = outputs.logits
    ```
"""

PT_CAUSAL_LM_SAMPLE = r"""
    Example:

    ```python
    >>> import torch
    >>> from transformers import AutoTokenizer, {model_class}

    >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")

    >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
    >>> outputs = model(**inputs, labels=inputs["input_ids"])
    >>> loss = outputs.loss
    >>> logits = outputs.logits
    ```
"""

PT_SPEECH_BASE_MODEL_SAMPLE = r"""
    Example:

    ```python
    >>> from transformers import AutoProcessor, {model_class}
    >>> import torch
    >>> from datasets import load_dataset

    >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation", trust_remote_code=True)
    >>> dataset = dataset.sort("id")
    >>> sampling_rate = dataset.features["audio"].sampling_rate

    >>> processor = AutoProcessor.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")

    >>> # audio file is decoded on the fly
    >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
    >>> with torch.no_grad():
    ...     outputs = model(**inputs)

    >>> last_hidden_states = outputs.last_hidden_state
    >>> list(last_hidden_states.shape)
    {expected_output}
    ```
"""

PT_SPEECH_CTC_SAMPLE = r"""
    Example:

    ```python
    >>> from transformers import AutoProcessor, {model_class}
    >>> from datasets import load_dataset
    >>> import torch

    >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation", trust_remote_code=True)
    >>> dataset = dataset.sort("id")
    >>> sampling_rate = dataset.features["audio"].sampling_rate

    >>> processor = AutoProcessor.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")

    >>> # audio file is decoded on the fly
    >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
    >>> with torch.no_grad():
    ...     logits = model(**inputs).logits
    >>> predicted_ids = torch.argmax(logits, dim=-1)

    >>> # transcribe speech
    >>> transcription = processor.batch_decode(predicted_ids)
    >>> transcription[0]
    {expected_output}

    >>> inputs["labels"] = processor(text=dataset[0]["text"], return_tensors="pt").input_ids

    >>> # compute loss
    >>> loss = model(**inputs).loss
    >>> round(loss.item(), 2)
    {expected_loss}
    ```
"""

PT_SPEECH_SEQ_CLASS_SAMPLE = r"""
    Example:

    ```python
    >>> from transformers import AutoFeatureExtractor, {model_class}
    >>> from datasets import load_dataset
    >>> import torch

    >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation", trust_remote_code=True)
    >>> dataset = dataset.sort("id")
    >>> sampling_rate = dataset.features["audio"].sampling_rate

    >>> feature_extractor = AutoFeatureExtractor.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")

    >>> # audio file is decoded on the fly
    >>> inputs = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")

    >>> with torch.no_grad():
    ...     logits = model(**inputs).logits

    >>> predicted_class_ids = torch.argmax(logits, dim=-1).item()
    >>> predicted_label = model.config.id2label[predicted_class_ids]
    >>> predicted_label
    {expected_output}

    >>> # compute loss - target_label is e.g. "down"
    >>> target_label = model.config.id2label[0]
    >>> inputs["labels"] = torch.tensor([model.config.label2id[target_label]])
    >>> loss = model(**inputs).loss
    >>> round(loss.item(), 2)
    {expected_loss}
    ```
"""

PT_SPEECH_FRAME_CLASS_SAMPLE = r"""
    Example:

    ```python
    >>> from transformers import AutoFeatureExtractor, {model_class}
    >>> from datasets import load_dataset
    >>> import torch

    >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation", trust_remote_code=True)
    >>> dataset = dataset.sort("id")
    >>> sampling_rate = dataset.features["audio"].sampling_rate

    >>> feature_extractor = AutoFeatureExtractor.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")

    >>> # audio file is decoded on the fly
    >>> inputs = feature_extractor(dataset[0]["audio"]["array"], return_tensors="pt", sampling_rate=sampling_rate)
    >>> with torch.no_grad():
    ...     logits = model(**inputs).logits

    >>> probabilities = torch.sigmoid(logits[0])
    >>> # labels is a one-hot array of shape (num_frames, num_speakers)
    >>> labels = (probabilities > 0.5).long()
    >>> labels[0].tolist()
    {expected_output}
    ```
"""

PT_SPEECH_XVECTOR_SAMPLE = r"""
    Example:

    ```python
    >>> from transformers import AutoFeatureExtractor, {model_class}
    >>> from datasets import load_dataset
    >>> import torch

    >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation", trust_remote_code=True)
    >>> dataset = dataset.sort("id")
    >>> sampling_rate = dataset.features["audio"].sampling_rate

    >>> feature_extractor = AutoFeatureExtractor.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")

    >>> # audio file is decoded on the fly
    >>> inputs = feature_extractor(
    ...     [d["array"] for d in dataset[:2]["audio"]], sampling_rate=sampling_rate, return_tensors="pt", padding=True
    ... )
    >>> with torch.no_grad():
    ...     embeddings = model(**inputs).embeddings

    >>> embeddings = torch.nn.functional.normalize(embeddings, dim=-1).cpu()

    >>> # the resulting embeddings can be used for cosine similarity-based retrieval
    >>> cosine_sim = torch.nn.CosineSimilarity(dim=-1)
    >>> similarity = cosine_sim(embeddings[0], embeddings[1])
    >>> threshold = 0.7  # the optimal threshold is dataset-dependent
    >>> if similarity < threshold:
    ...     print("Speakers are not the same!")
    >>> round(similarity.item(), 2)
    {expected_output}
    ```
"""

PT_VISION_BASE_MODEL_SAMPLE = r"""
    Example:

    ```python
    >>> from transformers import AutoImageProcessor, {model_class}
    >>> import torch
    >>> from datasets import load_dataset

    >>> dataset = load_dataset("huggingface/cats-image", trust_remote_code=True)
    >>> image = dataset["test"]["image"][0]

    >>> image_processor = AutoImageProcessor.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")

    >>> inputs = image_processor(image, return_tensors="pt")

    >>> with torch.no_grad():
    ...     outputs = model(**inputs)

    >>> last_hidden_states = outputs.last_hidden_state
    >>> list(last_hidden_states.shape)
    {expected_output}
    ```
"""

PT_VISION_SEQ_CLASS_SAMPLE = r"""
    Example:

    ```python
    >>> from transformers import AutoImageProcessor, {model_class}
    >>> import torch
    >>> from datasets import load_dataset

    >>> dataset = load_dataset("huggingface/cats-image", trust_remote_code=True)
    >>> image = dataset["test"]["image"][0]

    >>> image_processor = AutoImageProcessor.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")

    >>> inputs = image_processor(image, return_tensors="pt")

    >>> with torch.no_grad():
    ...     logits = model(**inputs).logits

    >>> # model predicts one of the 1000 ImageNet classes
    >>> predicted_label = logits.argmax(-1).item()
    >>> print(model.config.id2label[predicted_label])
    {expected_output}
    ```
"""

PT_SAMPLE_DOCSTRINGS = {
    "SequenceClassification": PT_SEQUENCE_CLASSIFICATION_SAMPLE,
    "QuestionAnswering": PT_QUESTION_ANSWERING_SAMPLE,
    "TokenClassification": PT_TOKEN_CLASSIFICATION_SAMPLE,
    "MultipleChoice": PT_MULTIPLE_CHOICE_SAMPLE,
    "MaskedLM": PT_MASKED_LM_SAMPLE,
    "LMHead": PT_CAUSAL_LM_SAMPLE,
    "BaseModel": PT_BASE_MODEL_SAMPLE,
    "SpeechBaseModel": PT_SPEECH_BASE_MODEL_SAMPLE,
    "CTC": PT_SPEECH_CTC_SAMPLE,
    "AudioClassification": PT_SPEECH_SEQ_CLASS_SAMPLE,
    "AudioFrameClassification": PT_SPEECH_FRAME_CLASS_SAMPLE,
    "AudioXVector": PT_SPEECH_XVECTOR_SAMPLE,
    "VisionBaseModel": PT_VISION_BASE_MODEL_SAMPLE,
    "ImageClassification": PT_VISION_SEQ_CLASS_SAMPLE,
}

TF_TOKEN_CLASSIFICATION_SAMPLE = r"""
    Example:

    ```python
    >>> from transformers import AutoTokenizer, {model_class}
    >>> import tensorflow as tf

    >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")

    >>> inputs = tokenizer(
    ...     "HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="tf"
    ... )

    >>> logits = model(**inputs).logits
    >>> predicted_token_class_ids = tf.math.argmax(logits, axis=-1)

    >>> # Note that tokens are classified rather than input words, which means that
    >>> # there might be more predicted token classes than words.
    >>> # Multiple token classes might account for the same word
    >>> predicted_tokens_classes = [model.config.id2label[t] for t in predicted_token_class_ids[0].numpy().tolist()]
    >>> predicted_tokens_classes
    {expected_output}
    ```

    ```python
    >>> labels = predicted_token_class_ids
    >>> loss = tf.math.reduce_mean(model(**inputs, labels=labels).loss)
    >>> round(float(loss), 2)
    {expected_loss}
    ```
"""

TF_QUESTION_ANSWERING_SAMPLE = r"""
    Example:

    ```python
    >>> from transformers import AutoTokenizer, {model_class}
    >>> import tensorflow as tf

    >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")

    >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
    >>> inputs = tokenizer(question, text, return_tensors="tf")
    >>> outputs = model(**inputs)

    >>> answer_start_index = int(tf.math.argmax(outputs.start_logits, axis=-1)[0])
    >>> answer_end_index = int(tf.math.argmax(outputs.end_logits, axis=-1)[0])

    >>> predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
    >>> tokenizer.decode(predict_answer_tokens)
    {expected_output}
    ```

    ```python
    >>> # target is "nice puppet"
    >>> target_start_index = tf.constant([{qa_target_start_index}])
    >>> target_end_index = tf.constant([{qa_target_end_index}])

    >>> outputs = model(**inputs, start_positions=target_start_index, end_positions=target_end_index)
    >>> loss = tf.math.reduce_mean(outputs.loss)
    >>> round(float(loss), 2)
    {expected_loss}
    ```
"""

TF_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
    Example:

    ```python
    >>> from transformers import AutoTokenizer, {model_class}
    >>> import tensorflow as tf

    >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")

    >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")

    >>> logits = model(**inputs).logits

    >>> predicted_class_id = int(tf.math.argmax(logits, axis=-1)[0])
    >>> model.config.id2label[predicted_class_id]
    {expected_output}
    ```

    ```python
    >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
    >>> num_labels = len(model.config.id2label)
    >>> model = {model_class}.from_pretrained("{checkpoint}", num_labels=num_labels)

    >>> labels = tf.constant(1)
    >>> loss = model(**inputs, labels=labels).loss
    >>> round(float(loss), 2)
    {expected_loss}
    ```
"""

TF_MASKED_LM_SAMPLE = r"""
    Example:

    ```python
    >>> from transformers import AutoTokenizer, {model_class}
    >>> import tensorflow as tf

    >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")

    >>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="tf")
    >>> logits = model(**inputs).logits

    >>> # retrieve index of {mask}
    >>> mask_token_index = tf.where((inputs.input_ids == tokenizer.mask_token_id)[0])
    >>> selected_logits = tf.gather_nd(logits[0], indices=mask_token_index)

    >>> predicted_token_id = tf.math.argmax(selected_logits, axis=-1)
    >>> tokenizer.decode(predicted_token_id)
    {expected_output}
    ```

    ```python
    >>> labels = tokenizer("The capital of France is Paris.", return_tensors="tf")["input_ids"]
    >>> # mask labels of non-{mask} tokens
    >>> labels = tf.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)

    >>> outputs = model(**inputs, labels=labels)
    >>> round(float(outputs.loss), 2)
    {expected_loss}
    ```
"""

TF_BASE_MODEL_SAMPLE = r"""
    Example:

    ```python
    >>> from transformers import AutoTokenizer, {model_class}
    >>> import tensorflow as tf

    >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")

    >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
    >>> outputs = model(inputs)

    >>> last_hidden_states = outputs.last_hidden_state
    ```
"""

TF_MULTIPLE_CHOICE_SAMPLE = r"""
    Example:

    ```python
    >>> from transformers import AutoTokenizer, {model_class}
    >>> import tensorflow as tf

    >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")

    >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
    >>> choice0 = "It is eaten with a fork and a knife."
    >>> choice1 = "It is eaten while held in the hand."

    >>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="tf", padding=True)
    >>> inputs = {{k: tf.expand_dims(v, 0) for k, v in encoding.items()}}
    >>> outputs = model(inputs)  # batch size is 1

    >>> # the linear classifier still needs to be trained
    >>> logits = outputs.logits
    ```
"""

TF_CAUSAL_LM_SAMPLE = r"""
    Example:

    ```python
    >>> from transformers import AutoTokenizer, {model_class}
    >>> import tensorflow as tf

    >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")

    >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
    >>> outputs = model(inputs)
    >>> logits = outputs.logits
    ```
"""

TF_SPEECH_BASE_MODEL_SAMPLE = r"""
    Example:

    ```python
    >>> from transformers import AutoProcessor, {model_class}
    >>> from datasets import load_dataset

    >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation", trust_remote_code=True)
    >>> dataset = dataset.sort("id")
    >>> sampling_rate = dataset.features["audio"].sampling_rate

    >>> processor = AutoProcessor.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")

    >>> # audio file is decoded on the fly
    >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="tf")
    >>> outputs = model(**inputs)

    >>> last_hidden_states = outputs.last_hidden_state
    >>> list(last_hidden_states.shape)
    {expected_output}
    ```
"""

TF_SPEECH_CTC_SAMPLE = r"""
    Example:

    ```python
    >>> from transformers import AutoProcessor, {model_class}
    >>> from datasets import load_dataset
    >>> import tensorflow as tf

    >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation", trust_remote_code=True)
    >>> dataset = dataset.sort("id")
    >>> sampling_rate = dataset.features["audio"].sampling_rate

    >>> processor = AutoProcessor.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")

    >>> # audio file is decoded on the fly
    >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="tf")
    >>> logits = model(**inputs).logits
    >>> predicted_ids = tf.math.argmax(logits, axis=-1)

    >>> # transcribe speech
    >>> transcription = processor.batch_decode(predicted_ids)
    >>> transcription[0]
    {expected_output}
    ```

    ```python
    >>> inputs["labels"] = processor(text=dataset[0]["text"], return_tensors="tf").input_ids

    >>> # compute loss
    >>> loss = model(**inputs).loss
    >>> round(float(loss), 2)
    {expected_loss}
    ```
"""

TF_VISION_BASE_MODEL_SAMPLE = r"""
    Example:

    ```python
    >>> from transformers import AutoImageProcessor, {model_class}
    >>> from datasets import load_dataset

    >>> dataset = load_dataset("huggingface/cats-image", trust_remote_code=True)
    >>> image = dataset["test"]["image"][0]

    >>> image_processor = AutoImageProcessor.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")

    >>> inputs = image_processor(image, return_tensors="tf")
    >>> outputs = model(**inputs)

    >>> last_hidden_states = outputs.last_hidden_state
    >>> list(last_hidden_states.shape)
    {expected_output}
    ```
"""

TF_VISION_SEQ_CLASS_SAMPLE = r"""
    Example:

    ```python
    >>> from transformers import AutoImageProcessor, {model_class}
    >>> import tensorflow as tf
    >>> from datasets import load_dataset

    >>> dataset = load_dataset("huggingface/cats-image", trust_remote_code=True)
    >>> image = dataset["test"]["image"][0]

    >>> image_processor = AutoImageProcessor.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")

    >>> inputs = image_processor(image, return_tensors="tf")
    >>> logits = model(**inputs).logits

    >>> # model predicts one of the 1000 ImageNet classes
    >>> predicted_label = int(tf.math.argmax(logits, axis=-1))
    >>> print(model.config.id2label[predicted_label])
    {expected_output}
    ```
"""

TF_SAMPLE_DOCSTRINGS = {
    "SequenceClassification": TF_SEQUENCE_CLASSIFICATION_SAMPLE,
    "QuestionAnswering": TF_QUESTION_ANSWERING_SAMPLE,
    "TokenClassification": TF_TOKEN_CLASSIFICATION_SAMPLE,
    "MultipleChoice": TF_MULTIPLE_CHOICE_SAMPLE,
    "MaskedLM": TF_MASKED_LM_SAMPLE,
    "LMHead": TF_CAUSAL_LM_SAMPLE,
    "BaseModel": TF_BASE_MODEL_SAMPLE,
    "SpeechBaseModel": TF_SPEECH_BASE_MODEL_SAMPLE,
    "CTC": TF_SPEECH_CTC_SAMPLE,
    "VisionBaseModel": TF_VISION_BASE_MODEL_SAMPLE,
    "ImageClassification": TF_VISION_SEQ_CLASS_SAMPLE,
}

FLAX_TOKEN_CLASSIFICATION_SAMPLE = r"""
    Example:

    ```python
    >>> from transformers import AutoTokenizer, {model_class}

    >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")

    >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")

    >>> outputs = model(**inputs)
    >>> logits = outputs.logits
    ```
"""

FLAX_QUESTION_ANSWERING_SAMPLE = r"""
    Example:

    ```python
    >>> from transformers import AutoTokenizer, {model_class}

    >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")

    >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
    >>> inputs = tokenizer(question, text, return_tensors="jax")

    >>> outputs = model(**inputs)
    >>> start_scores = outputs.start_logits
    >>> end_scores = outputs.end_logits
    ```
"""

FLAX_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
    Example:

    ```python
    >>> from transformers import AutoTokenizer, {model_class}

    >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")

    >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")

    >>> outputs = model(**inputs)
    >>> logits = outputs.logits
    ```
"""

FLAX_MASKED_LM_SAMPLE = r"""
    Example:

    ```python
    >>> from transformers import AutoTokenizer, {model_class}

    >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")

    >>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="jax")

    >>> outputs = model(**inputs)
    >>> logits = outputs.logits
    ```
"""

FLAX_BASE_MODEL_SAMPLE = r"""
    Example:

    ```python
    >>> from transformers import AutoTokenizer, {model_class}

    >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")

    >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
    >>> outputs = model(**inputs)

    >>> last_hidden_states = outputs.last_hidden_state
    ```
"""

FLAX_MULTIPLE_CHOICE_SAMPLE = r"""
    Example:

    ```python
    >>> from transformers import AutoTokenizer, {model_class}

    >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")

    >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
    >>> choice0 = "It is eaten with a fork and a knife."
    >>> choice1 = "It is eaten while held in the hand."

    >>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="jax", padding=True)
    >>> outputs = model(**{{k: v[None, :] for k, v in encoding.items()}})

    >>> logits = outputs.logits
    ```
"""

FLAX_CAUSAL_LM_SAMPLE = r"""
    Example:

    ```python
    >>> from transformers import AutoTokenizer, {model_class}

    >>> tokenizer = AutoTokenizer.from_pretrained("{checkpoint}")
    >>> model = {model_class}.from_pretrained("{checkpoint}")

    >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
    >>> outputs = model(**inputs)

    >>> # retrieve logits for next token
    >>> next_token_logits = outputs.logits[:, -1]
    ```
"""

FLAX_SAMPLE_DOCSTRINGS = {
    "SequenceClassification": FLAX_SEQUENCE_CLASSIFICATION_SAMPLE,
    "QuestionAnswering": FLAX_QUESTION_ANSWERING_SAMPLE,
    "TokenClassification": FLAX_TOKEN_CLASSIFICATION_SAMPLE,
    "MultipleChoice": FLAX_MULTIPLE_CHOICE_SAMPLE,
    "MaskedLM": FLAX_MASKED_LM_SAMPLE,
    "BaseModel": FLAX_BASE_MODEL_SAMPLE,
    "LMHead": FLAX_CAUSAL_LM_SAMPLE,
}


def filter_outputs_from_example(docstring, **kwargs):
    """
    Removes the lines that test an output with the doctest syntax from a code sample when the corresponding value is
    set to `None`.
    """
    for key, value in kwargs.items():
        if value is not None:
            continue
        doc_key = "{" + key + "}"
        docstring = re.sub(rf"\n([^\n]+)\n\s+{doc_key}\n", "\n", docstring)
    return docstring
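
# Illustrative sketch (not part of the original module): when `expected_output=None`, the call
# `filter_outputs_from_example(sample, expected_output=None)` strips both the `{expected_output}`
# placeholder line and the doctest line right above it (the one that would have produced that output),
# so the rendered example no longer advertises a value that is not checked.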


def add_code_sample_docstrings(
    *docstr,
    processor_class=None,
    checkpoint=None,
    output_type=None,
    config_class=None,
    mask="[MASK]",
    qa_target_start_index=14,
    qa_target_end_index=15,
    model_cls=None,
    modality=None,
    expected_output=None,
    expected_loss=None,
    real_checkpoint=None,
    revision=None,
):
    def docstring_decorator(fn):
        # model_class defaults to the function's class if not specified otherwise
        model_class = fn.__qualname__.split(".")[0] if model_cls is None else model_cls

        if model_class[:2] == "TF":
            sample_docstrings = TF_SAMPLE_DOCSTRINGS
        elif model_class[:4] == "Flax":
            sample_docstrings = FLAX_SAMPLE_DOCSTRINGS
        else:
            sample_docstrings = PT_SAMPLE_DOCSTRINGS

        # Put all kwargs for the docstrings in a dict to be used with `.format(**doc_kwargs)`.
        # Note that the string might be formatted with non-existing keys, which is fine.
        doc_kwargs = {
            "model_class": model_class,
            "processor_class": processor_class,
            "checkpoint": checkpoint,
            "mask": mask,
            "qa_target_start_index": qa_target_start_index,
            "qa_target_end_index": qa_target_end_index,
            "expected_output": expected_output,
            "expected_loss": expected_loss,
            "real_checkpoint": real_checkpoint,
            "fake_checkpoint": checkpoint,
            "true": "{true}",  # For <Tip warning={true}> syntax that conflicts with formatting.
        }

        if ("SequenceClassification" in model_class or "AudioClassification" in model_class) and modality == "audio":
            code_sample = sample_docstrings["AudioClassification"]
        elif "SequenceClassification" in model_class:
            code_sample = sample_docstrings["SequenceClassification"]
        elif "QuestionAnswering" in model_class:
            code_sample = sample_docstrings["QuestionAnswering"]
        elif "TokenClassification" in model_class:
            code_sample = sample_docstrings["TokenClassification"]
        elif "MultipleChoice" in model_class:
            code_sample = sample_docstrings["MultipleChoice"]
        elif "MaskedLM" in model_class or model_class in ["FlaubertWithLMHeadModel", "XLMWithLMHeadModel"]:
            code_sample = sample_docstrings["MaskedLM"]
        elif "LMHead" in model_class or "CausalLM" in model_class:
            code_sample = sample_docstrings["LMHead"]
        elif "CTC" in model_class:
            code_sample = sample_docstrings["CTC"]
        elif "AudioFrameClassification" in model_class:
            code_sample = sample_docstrings["AudioFrameClassification"]
        elif "XVector" in model_class and modality == "audio":
            code_sample = sample_docstrings["AudioXVector"]
        elif "Model" in model_class and modality == "audio":
            code_sample = sample_docstrings["SpeechBaseModel"]
        elif "Model" in model_class and modality == "vision":
            code_sample = sample_docstrings["VisionBaseModel"]
        elif "Model" in model_class or "Encoder" in model_class:
            code_sample = sample_docstrings["BaseModel"]
        elif "ImageClassification" in model_class:
            code_sample = sample_docstrings["ImageClassification"]
        else:
            raise ValueError(f"Docstring can't be built for model {model_class}")

        code_sample = filter_outputs_from_example(
            code_sample, expected_output=expected_output, expected_loss=expected_loss
        )
        if real_checkpoint is not None:
            code_sample = FAKE_MODEL_DISCLAIMER + code_sample
        func_doc = (fn.__doc__ or "") + "".join(docstr)
        output_doc = "" if output_type is None else _prepare_output_docstrings(output_type, config_class)
        built_doc = code_sample.format(**doc_kwargs)
        if revision is not None:
            if not re.match(r"^refs/pr/\d+", revision):
                raise ValueError(
                    f"The provided revision '{revision}' is incorrect. It should point to"
                    " a pull request reference on the hub like 'refs/pr/6'"
                )
            built_doc = built_doc.replace(
                f'from_pretrained("{checkpoint}")', f'from_pretrained("{checkpoint}", revision="{revision}")'
            )
        fn.__doc__ = func_doc + output_doc + built_doc
        return fn

    return docstring_decorator
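
# Illustrative sketch (not part of the original module): this decorator is typically applied to a model's
# `forward`/`call` method together with the other helpers in this file, roughly like
#
#     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING)  # hypothetical docstring constant
#     @add_code_sample_docstrings(
#         checkpoint="google-bert/bert-base-uncased",  # hypothetical checkpoint name
#         output_type=TokenClassifierOutput,
#         config_class="BertConfig",
#     )
#     def forward(self, input_ids=None, ...):
#         ...
#
# which appends a ready-made, checkpoint-specific usage example to the method's docstring.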


def replace_return_docstrings(output_type=None, config_class=None):
    def docstring_decorator(fn):
        func_doc = fn.__doc__
        lines = func_doc.split("\n")
        i = 0
        while i < len(lines) and re.search(r"^\s*Returns?:\s*$", lines[i]) is None:
            i += 1
        if i < len(lines):
            indent = len(_get_indent(lines[i]))
            lines[i] = _prepare_output_docstrings(output_type, config_class, min_indent=indent)
            func_doc = "\n".join(lines)
        else:
            raise ValueError(
                f"The function {fn} should have an empty 'Return:' or 'Returns:' in its docstring as placeholder, "
                f"current docstring is:\n{func_doc}"
            )
        fn.__doc__ = func_doc
        return fn

    return docstring_decorator
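
# Illustrative sketch (not part of the original module): `replace_return_docstrings` expects a bare
# `Returns:` placeholder line in the decorated method's docstring, e.g.
#
#     @replace_return_docstrings(output_type=CausalLMOutput, config_class="GPT2Config")  # hypothetical arguments
#     def forward(self, ...):
#         r"""
#         ...
#
#         Returns:
#
#         Example: ...
#         """
#
# and swaps that line for the full return documentation generated by `_prepare_output_docstrings`.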


def copy_func(f):
    """Returns a copy of a function f."""
    # Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)
    g = types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)
    g = functools.update_wrapper(g, f)
    g.__kwdefaults__ = f.__kwdefaults__
    return g
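
# Illustrative sketch (not part of the original module): `copy_func` is handy when the same underlying
# function needs two different docstrings, e.g.
#
#     def f(x):
#         """Original docstring."""
#         return x
#
#     g = copy_func(f)
#     g.__doc__ = "A differently documented copy of `f`."
#     # f.__doc__ is unchanged, and g behaves exactly like f.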