__init__.py 12 KB


# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from typing import TYPE_CHECKING
  15. from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_tf_available, is_torch_available
  16. _import_structure = {
  17. "configuration_utils": [
  18. "BaseWatermarkingConfig",
  19. "GenerationConfig",
  20. "GenerationMode",
  21. "SynthIDTextWatermarkingConfig",
  22. "WatermarkingConfig",
  23. ],
  24. "streamers": ["TextIteratorStreamer", "TextStreamer"],
  25. }
  26. try:
  27. if not is_torch_available():
  28. raise OptionalDependencyNotAvailable()
  29. except OptionalDependencyNotAvailable:
  30. pass
  31. else:
  32. _import_structure["beam_constraints"] = [
  33. "Constraint",
  34. "ConstraintListState",
  35. "DisjunctiveConstraint",
  36. "PhrasalConstraint",
  37. ]
  38. _import_structure["beam_search"] = [
  39. "BeamHypotheses",
  40. "BeamScorer",
  41. "BeamSearchScorer",
  42. "ConstrainedBeamSearchScorer",
  43. ]
  44. _import_structure["candidate_generator"] = [
  45. "AssistedCandidateGenerator",
  46. "CandidateGenerator",
  47. "PromptLookupCandidateGenerator",
  48. ]
  49. _import_structure["logits_process"] = [
  50. "AlternatingCodebooksLogitsProcessor",
  51. "ClassifierFreeGuidanceLogitsProcessor",
  52. "EncoderNoRepeatNGramLogitsProcessor",
  53. "EncoderRepetitionPenaltyLogitsProcessor",
  54. "EpsilonLogitsWarper",
  55. "EtaLogitsWarper",
  56. "ExponentialDecayLengthPenalty",
  57. "ForcedBOSTokenLogitsProcessor",
  58. "ForcedEOSTokenLogitsProcessor",
  59. "HammingDiversityLogitsProcessor",
  60. "InfNanRemoveLogitsProcessor",
  61. "LogitNormalization",
  62. "LogitsProcessor",
  63. "LogitsProcessorList",
  64. "LogitsWarper",
  65. "MinLengthLogitsProcessor",
  66. "MinNewTokensLengthLogitsProcessor",
  67. "MinPLogitsWarper",
  68. "NoBadWordsLogitsProcessor",
  69. "NoRepeatNGramLogitsProcessor",
  70. "PrefixConstrainedLogitsProcessor",
  71. "RepetitionPenaltyLogitsProcessor",
  72. "SequenceBiasLogitsProcessor",
  73. "SuppressTokensLogitsProcessor",
  74. "SuppressTokensAtBeginLogitsProcessor",
  75. "SynthIDTextWatermarkLogitsProcessor",
  76. "TemperatureLogitsWarper",
  77. "TopKLogitsWarper",
  78. "TopPLogitsWarper",
  79. "TypicalLogitsWarper",
  80. "UnbatchedClassifierFreeGuidanceLogitsProcessor",
  81. "WhisperTimeStampLogitsProcessor",
  82. "WatermarkLogitsProcessor",
  83. ]
  84. _import_structure["stopping_criteria"] = [
  85. "MaxNewTokensCriteria",
  86. "MaxLengthCriteria",
  87. "MaxTimeCriteria",
  88. "ConfidenceCriteria",
  89. "EosTokenCriteria",
  90. "StoppingCriteria",
  91. "StoppingCriteriaList",
  92. "validate_stopping_criteria",
  93. "StopStringCriteria",
  94. ]
  95. _import_structure["utils"] = [
  96. "GenerationMixin",
  97. "GreedySearchEncoderDecoderOutput",
  98. "GreedySearchDecoderOnlyOutput",
  99. "SampleEncoderDecoderOutput",
  100. "SampleDecoderOnlyOutput",
  101. "BeamSearchEncoderDecoderOutput",
  102. "BeamSearchDecoderOnlyOutput",
  103. "BeamSampleEncoderDecoderOutput",
  104. "BeamSampleDecoderOnlyOutput",
  105. "ContrastiveSearchEncoderDecoderOutput",
  106. "ContrastiveSearchDecoderOnlyOutput",
  107. "GenerateBeamDecoderOnlyOutput",
  108. "GenerateBeamEncoderDecoderOutput",
  109. "GenerateDecoderOnlyOutput",
  110. "GenerateEncoderDecoderOutput",
  111. ]
  112. _import_structure["watermarking"] = [
  113. "WatermarkDetector",
  114. "WatermarkDetectorOutput",
  115. "BayesianDetectorModel",
  116. "BayesianDetectorConfig",
  117. "SynthIDTextWatermarkDetector",
  118. ]
  119. try:
  120. if not is_tf_available():
  121. raise OptionalDependencyNotAvailable()
  122. except OptionalDependencyNotAvailable:
  123. pass
  124. else:
  125. _import_structure["tf_logits_process"] = [
  126. "TFForcedBOSTokenLogitsProcessor",
  127. "TFForcedEOSTokenLogitsProcessor",
  128. "TFForceTokensLogitsProcessor",
  129. "TFLogitsProcessor",
  130. "TFLogitsProcessorList",
  131. "TFLogitsWarper",
  132. "TFMinLengthLogitsProcessor",
  133. "TFNoBadWordsLogitsProcessor",
  134. "TFNoRepeatNGramLogitsProcessor",
  135. "TFRepetitionPenaltyLogitsProcessor",
  136. "TFSuppressTokensAtBeginLogitsProcessor",
  137. "TFSuppressTokensLogitsProcessor",
  138. "TFTemperatureLogitsWarper",
  139. "TFTopKLogitsWarper",
  140. "TFTopPLogitsWarper",
  141. ]
  142. _import_structure["tf_utils"] = [
  143. "TFGenerationMixin",
  144. "TFGreedySearchDecoderOnlyOutput",
  145. "TFGreedySearchEncoderDecoderOutput",
  146. "TFSampleEncoderDecoderOutput",
  147. "TFSampleDecoderOnlyOutput",
  148. "TFBeamSearchEncoderDecoderOutput",
  149. "TFBeamSearchDecoderOnlyOutput",
  150. "TFBeamSampleEncoderDecoderOutput",
  151. "TFBeamSampleDecoderOnlyOutput",
  152. "TFContrastiveSearchEncoderDecoderOutput",
  153. "TFContrastiveSearchDecoderOnlyOutput",
  154. ]
  155. try:
  156. if not is_flax_available():
  157. raise OptionalDependencyNotAvailable()
  158. except OptionalDependencyNotAvailable:
  159. pass
  160. else:
  161. _import_structure["flax_logits_process"] = [
  162. "FlaxForcedBOSTokenLogitsProcessor",
  163. "FlaxForcedEOSTokenLogitsProcessor",
  164. "FlaxForceTokensLogitsProcessor",
  165. "FlaxLogitsProcessor",
  166. "FlaxLogitsProcessorList",
  167. "FlaxLogitsWarper",
  168. "FlaxMinLengthLogitsProcessor",
  169. "FlaxSuppressTokensAtBeginLogitsProcessor",
  170. "FlaxSuppressTokensLogitsProcessor",
  171. "FlaxTemperatureLogitsWarper",
  172. "FlaxTopKLogitsWarper",
  173. "FlaxTopPLogitsWarper",
  174. "FlaxWhisperTimeStampLogitsProcessor",
  175. "FlaxNoRepeatNGramLogitsProcessor",
  176. ]
  177. _import_structure["flax_utils"] = [
  178. "FlaxGenerationMixin",
  179. "FlaxGreedySearchOutput",
  180. "FlaxSampleOutput",
  181. "FlaxBeamSearchOutput",
  182. ]
  183. if TYPE_CHECKING:
  184. from .configuration_utils import (
  185. BaseWatermarkingConfig,
  186. GenerationConfig,
  187. GenerationMode,
  188. SynthIDTextWatermarkingConfig,
  189. WatermarkingConfig,
  190. )
  191. from .streamers import TextIteratorStreamer, TextStreamer
  192. try:
  193. if not is_torch_available():
  194. raise OptionalDependencyNotAvailable()
  195. except OptionalDependencyNotAvailable:
  196. pass
  197. else:
  198. from .beam_constraints import Constraint, ConstraintListState, DisjunctiveConstraint, PhrasalConstraint
  199. from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
  200. from .candidate_generator import AssistedCandidateGenerator, CandidateGenerator, PromptLookupCandidateGenerator
  201. from .logits_process import (
  202. AlternatingCodebooksLogitsProcessor,
  203. ClassifierFreeGuidanceLogitsProcessor,
  204. EncoderNoRepeatNGramLogitsProcessor,
  205. EncoderRepetitionPenaltyLogitsProcessor,
  206. EpsilonLogitsWarper,
  207. EtaLogitsWarper,
  208. ExponentialDecayLengthPenalty,
  209. ForcedBOSTokenLogitsProcessor,
  210. ForcedEOSTokenLogitsProcessor,
  211. HammingDiversityLogitsProcessor,
  212. InfNanRemoveLogitsProcessor,
  213. LogitNormalization,
  214. LogitsProcessor,
  215. LogitsProcessorList,
  216. LogitsWarper,
  217. MinLengthLogitsProcessor,
  218. MinNewTokensLengthLogitsProcessor,
  219. MinPLogitsWarper,
  220. NoBadWordsLogitsProcessor,
  221. NoRepeatNGramLogitsProcessor,
  222. PrefixConstrainedLogitsProcessor,
  223. RepetitionPenaltyLogitsProcessor,
  224. SequenceBiasLogitsProcessor,
  225. SuppressTokensAtBeginLogitsProcessor,
  226. SuppressTokensLogitsProcessor,
  227. SynthIDTextWatermarkLogitsProcessor,
  228. TemperatureLogitsWarper,
  229. TopKLogitsWarper,
  230. TopPLogitsWarper,
  231. TypicalLogitsWarper,
  232. UnbatchedClassifierFreeGuidanceLogitsProcessor,
  233. WatermarkLogitsProcessor,
  234. WhisperTimeStampLogitsProcessor,
  235. )
  236. from .stopping_criteria import (
  237. ConfidenceCriteria,
  238. EosTokenCriteria,
  239. MaxLengthCriteria,
  240. MaxNewTokensCriteria,
  241. MaxTimeCriteria,
  242. StoppingCriteria,
  243. StoppingCriteriaList,
  244. StopStringCriteria,
  245. validate_stopping_criteria,
  246. )
  247. from .utils import (
  248. BeamSampleDecoderOnlyOutput,
  249. BeamSampleEncoderDecoderOutput,
  250. BeamSearchDecoderOnlyOutput,
  251. BeamSearchEncoderDecoderOutput,
  252. ContrastiveSearchDecoderOnlyOutput,
  253. ContrastiveSearchEncoderDecoderOutput,
  254. GenerateBeamDecoderOnlyOutput,
  255. GenerateBeamEncoderDecoderOutput,
  256. GenerateDecoderOnlyOutput,
  257. GenerateEncoderDecoderOutput,
  258. GenerationMixin,
  259. GreedySearchDecoderOnlyOutput,
  260. GreedySearchEncoderDecoderOutput,
  261. SampleDecoderOnlyOutput,
  262. SampleEncoderDecoderOutput,
  263. )
  264. from .watermarking import (
  265. BayesianDetectorConfig,
  266. BayesianDetectorModel,
  267. SynthIDTextWatermarkDetector,
  268. WatermarkDetector,
  269. WatermarkDetectorOutput,
  270. )
  271. try:
  272. if not is_tf_available():
  273. raise OptionalDependencyNotAvailable()
  274. except OptionalDependencyNotAvailable:
  275. pass
  276. else:
  277. from .tf_logits_process import (
  278. TFForcedBOSTokenLogitsProcessor,
  279. TFForcedEOSTokenLogitsProcessor,
  280. TFForceTokensLogitsProcessor,
  281. TFLogitsProcessor,
  282. TFLogitsProcessorList,
  283. TFLogitsWarper,
  284. TFMinLengthLogitsProcessor,
  285. TFNoBadWordsLogitsProcessor,
  286. TFNoRepeatNGramLogitsProcessor,
  287. TFRepetitionPenaltyLogitsProcessor,
  288. TFSuppressTokensAtBeginLogitsProcessor,
  289. TFSuppressTokensLogitsProcessor,
  290. TFTemperatureLogitsWarper,
  291. TFTopKLogitsWarper,
  292. TFTopPLogitsWarper,
  293. )
  294. from .tf_utils import (
  295. TFBeamSampleDecoderOnlyOutput,
  296. TFBeamSampleEncoderDecoderOutput,
  297. TFBeamSearchDecoderOnlyOutput,
  298. TFBeamSearchEncoderDecoderOutput,
  299. TFContrastiveSearchDecoderOnlyOutput,
  300. TFContrastiveSearchEncoderDecoderOutput,
  301. TFGenerationMixin,
  302. TFGreedySearchDecoderOnlyOutput,
  303. TFGreedySearchEncoderDecoderOutput,
  304. TFSampleDecoderOnlyOutput,
  305. TFSampleEncoderDecoderOutput,
  306. )
  307. try:
  308. if not is_flax_available():
  309. raise OptionalDependencyNotAvailable()
  310. except OptionalDependencyNotAvailable:
  311. pass
  312. else:
  313. from .flax_logits_process import (
  314. FlaxForcedBOSTokenLogitsProcessor,
  315. FlaxForcedEOSTokenLogitsProcessor,
  316. FlaxForceTokensLogitsProcessor,
  317. FlaxLogitsProcessor,
  318. FlaxLogitsProcessorList,
  319. FlaxLogitsWarper,
  320. FlaxMinLengthLogitsProcessor,
  321. FlaxNoRepeatNGramLogitsProcessor,
  322. FlaxSuppressTokensAtBeginLogitsProcessor,
  323. FlaxSuppressTokensLogitsProcessor,
  324. FlaxTemperatureLogitsWarper,
  325. FlaxTopKLogitsWarper,
  326. FlaxTopPLogitsWarper,
  327. FlaxWhisperTimeStampLogitsProcessor,
  328. )
  329. from .flax_utils import FlaxBeamSearchOutput, FlaxGenerationMixin, FlaxGreedySearchOutput, FlaxSampleOutput
  330. else:
  331. import sys
  332. sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)