modeling_imagegpt.py 51 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165
  1. # coding=utf-8
  2. # Copyright 2021 The OpenAI Team Authors and HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch OpenAI ImageGPT model."""
  16. import math
  17. import os
  18. import warnings
  19. from typing import Any, Optional, Tuple, Union
  20. import torch
  21. import torch.utils.checkpoint
  22. from torch import nn
  23. from torch.cuda.amp import autocast
  24. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  25. from ...activations import ACT2FN
  26. from ...generation import GenerationMixin
  27. from ...modeling_outputs import (
  28. BaseModelOutputWithPastAndCrossAttentions,
  29. CausalLMOutputWithCrossAttentions,
  30. SequenceClassifierOutputWithPast,
  31. )
  32. from ...modeling_utils import PreTrainedModel
  33. from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
  34. from ...utils import (
  35. add_start_docstrings,
  36. add_start_docstrings_to_model_forward,
  37. logging,
  38. replace_return_docstrings,
  39. torch_float,
  40. )
  41. from .configuration_imagegpt import ImageGPTConfig
# Module-level logger obtained through the library's `logging` utility (imported from `...utils`).
logger = logging.get_logger(__name__)
# Names used by the docstring helpers. `_CONFIG_FOR_DOC` is passed to `replace_return_docstrings` on
# `ImageGPTModel.forward`; `_CHECKPOINT_FOR_DOC` is not referenced in the visible part of this file
# (presumably used by documentation tooling elsewhere — verify before removing).
_CHECKPOINT_FOR_DOC = "openai/imagegpt-small"
_CONFIG_FOR_DOC = "ImageGPTConfig"
  45. def load_tf_weights_in_imagegpt(model, config, imagegpt_checkpoint_path):
  46. """
  47. Load tf checkpoints in a pytorch model
  48. """
  49. try:
  50. import re
  51. import tensorflow as tf
  52. except ImportError:
  53. logger.error(
  54. "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
  55. "https://www.tensorflow.org/install/ for installation instructions."
  56. )
  57. raise
  58. tf_path = os.path.abspath(imagegpt_checkpoint_path)
  59. logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
  60. # Load weights from TF model
  61. init_vars = tf.train.list_variables(tf_path)
  62. names = []
  63. arrays = []
  64. for name, shape in init_vars:
  65. logger.info("Loading TF weight {} with shape {}".format(name, shape))
  66. array = tf.train.load_variable(tf_path, name)
  67. names.append(name)
  68. arrays.append(array.squeeze())
  69. for name, array in zip(names, arrays):
  70. name = name[6:] # skip "model/"
  71. name = name.split("/")
  72. # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
  73. # which are not required for using pretrained model
  74. if any(
  75. n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
  76. for n in name
  77. ) or name[-1] in ["_step"]:
  78. logger.info("Skipping {}".format("/".join(name)))
  79. continue
  80. pointer = model
  81. if name[-1] not in ["wtet"]:
  82. pointer = getattr(pointer, "transformer")
  83. for m_name in name:
  84. if re.fullmatch(r"[A-Za-z]+\d+", m_name):
  85. scope_names = re.split(r"(\d+)", m_name)
  86. else:
  87. scope_names = [m_name]
  88. if scope_names[0] == "w" or scope_names[0] == "g":
  89. pointer = getattr(pointer, "weight")
  90. elif scope_names[0] == "b":
  91. pointer = getattr(pointer, "bias")
  92. elif scope_names[0] == "wpe" or scope_names[0] == "wte":
  93. pointer = getattr(pointer, scope_names[0])
  94. pointer = getattr(pointer, "weight")
  95. elif scope_names[0] in ["q_proj", "k_proj", "v_proj"]:
  96. pointer = getattr(pointer, "c_attn")
  97. pointer = getattr(pointer, "weight")
  98. elif len(name) == 3 and name[1] == "attn" and scope_names[0] == "c_proj":
  99. pointer = getattr(pointer, scope_names[0])
  100. pointer = getattr(pointer, "weight")
  101. elif scope_names[0] == "wtet":
  102. pointer = getattr(pointer, "lm_head")
  103. pointer = getattr(pointer, "weight")
  104. elif scope_names[0] == "sos":
  105. pointer = getattr(pointer, "wte")
  106. pointer = getattr(pointer, "weight")
  107. else:
  108. pointer = getattr(pointer, scope_names[0])
  109. if len(scope_names) >= 2:
  110. num = int(scope_names[1])
  111. pointer = pointer[num]
  112. if len(name) > 1 and name[1] == "attn" or name[-1] == "wtet" or name[-1] == "sos" or name[-1] == "wte":
  113. pass # array is used to initialize only part of the pointer so sizes won't match
  114. else:
  115. try:
  116. assert pointer.shape == array.shape
  117. except AssertionError as e:
  118. e.args += (pointer.shape, array.shape)
  119. raise
  120. logger.info("Initialize PyTorch weight {}".format(name))
  121. if name[-1] == "q_proj":
  122. pointer.data[:, : config.n_embd] = torch.from_numpy(array.reshape(config.n_embd, config.n_embd)).T
  123. elif name[-1] == "k_proj":
  124. pointer.data[:, config.n_embd : 2 * config.n_embd] = torch.from_numpy(
  125. array.reshape(config.n_embd, config.n_embd)
  126. ).T
  127. elif name[-1] == "v_proj":
  128. pointer.data[:, 2 * config.n_embd :] = torch.from_numpy(array.reshape(config.n_embd, config.n_embd)).T
  129. elif len(name) == 3 and name[1] == "attn" and name[2] == "c_proj":
  130. pointer.data = torch.from_numpy(array.reshape(config.n_embd, config.n_embd))
  131. elif name[-1] == "wtet":
  132. pointer.data = torch.from_numpy(array)
  133. elif name[-1] == "wte":
  134. pointer.data[: config.vocab_size - 1, :] = torch.from_numpy(array)
  135. elif name[-1] == "sos":
  136. pointer.data[-1] = torch.from_numpy(array)
  137. else:
  138. pointer.data = torch.from_numpy(array)
  139. return model
  140. class ImageGPTLayerNorm(nn.Module):
  141. def __init__(self, hidden_size: Tuple[int], eps: float = 1e-5):
  142. super().__init__()
  143. self.eps = eps
  144. self.weight = nn.Parameter(torch.Tensor(hidden_size))
  145. def forward(self, tensor: torch.Tensor) -> tuple:
  146. # input is not mean centered
  147. return (
  148. tensor
  149. / torch.sqrt(torch.mean(torch.square(tensor), axis=-1, keepdim=True) + self.eps)
  150. * self.weight.data[..., :]
  151. )
  152. class ImageGPTAttention(nn.Module):
  153. def __init__(self, config, is_cross_attention: Optional[bool] = False, layer_idx: Optional[int] = None):
  154. super().__init__()
  155. max_positions = config.max_position_embeddings
  156. self.register_buffer(
  157. "bias",
  158. torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
  159. 1, 1, max_positions, max_positions
  160. ),
  161. persistent=False,
  162. )
  163. self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
  164. self.embed_dim = config.hidden_size
  165. self.num_heads = config.num_attention_heads
  166. self.head_dim = self.embed_dim // self.num_heads
  167. self.split_size = self.embed_dim
  168. if self.head_dim * self.num_heads != self.embed_dim:
  169. raise ValueError(
  170. f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  171. f" {self.num_heads})."
  172. )
  173. self.scale_attn_weights = config.scale_attn_weights
  174. self.is_cross_attention = is_cross_attention
  175. # Layer-wise attention scaling, reordering, and upcasting
  176. self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
  177. self.layer_idx = layer_idx
  178. self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
  179. if self.is_cross_attention:
  180. self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
  181. self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
  182. else:
  183. self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
  184. self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
  185. self.attn_dropout = nn.Dropout(config.attn_pdrop)
  186. self.resid_dropout = nn.Dropout(config.resid_pdrop)
  187. self.pruned_heads = set()
  188. def prune_heads(self, heads):
  189. if len(heads) == 0:
  190. return
  191. heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
  192. index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
  193. # Prune conv1d layers
  194. self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
  195. self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
  196. # Update hyper params
  197. self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
  198. self.num_heads = self.num_heads - len(heads)
  199. self.pruned_heads = self.pruned_heads.union(heads)
  200. def _attn(self, query, key, value, attention_mask=None, head_mask=None):
  201. attn_weights = torch.matmul(query, key.transpose(-1, -2))
  202. if self.scale_attn_weights:
  203. attn_weights = attn_weights / torch_float(value.size(-1) ** 0.5)
  204. # Layer-wise attention scaling
  205. if self.scale_attn_by_inverse_layer_idx:
  206. attn_weights = attn_weights / float(self.layer_idx + 1)
  207. if not self.is_cross_attention:
  208. # if only "normal" attention layer implements causal mask
  209. query_length, key_length = query.size(-2), key.size(-2)
  210. causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
  211. mask_value = torch.finfo(attn_weights.dtype).min
  212. # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
  213. # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
  214. mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
  215. attn_weights = torch.where(causal_mask, attn_weights, mask_value)
  216. if attention_mask is not None:
  217. # Apply the attention mask
  218. attn_weights = attn_weights + attention_mask
  219. attn_weights = nn.Softmax(dim=-1)(attn_weights)
  220. # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
  221. attn_weights = attn_weights.type(value.dtype)
  222. attn_weights = self.attn_dropout(attn_weights)
  223. # Mask heads if we want to
  224. if head_mask is not None:
  225. attn_weights = attn_weights * head_mask
  226. attn_output = torch.matmul(attn_weights, value)
  227. return attn_output, attn_weights
  228. def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
  229. # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
  230. bsz, num_heads, q_seq_len, dk = query.size()
  231. _, _, k_seq_len, _ = key.size()
  232. # Preallocate attn_weights for `baddbmm`
  233. attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)
  234. # Compute Scale Factor
  235. scale_factor = 1.0
  236. if self.scale_attn_weights:
  237. scale_factor /= float(value.size(-1)) ** 0.5
  238. if self.scale_attn_by_inverse_layer_idx:
  239. scale_factor /= float(self.layer_idx + 1)
  240. # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
  241. with autocast(enabled=False):
  242. q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
  243. attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
  244. attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
  245. if not self.is_cross_attention:
  246. # if only "normal" attention layer implements causal mask
  247. query_length, key_length = query.size(-2), key.size(-2)
  248. causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
  249. mask_value = torch.finfo(attn_weights.dtype).min
  250. # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
  251. # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
  252. mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
  253. attn_weights = torch.where(causal_mask, attn_weights, mask_value)
  254. if attention_mask is not None:
  255. # Apply the attention mask
  256. attn_weights = attn_weights + attention_mask
  257. attn_weights = nn.Softmax(dim=-1)(attn_weights)
  258. # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise
  259. if attn_weights.dtype != torch.float32:
  260. raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
  261. attn_weights = attn_weights.type(value.dtype)
  262. attn_weights = self.attn_dropout(attn_weights)
  263. # Mask heads if we want to
  264. if head_mask is not None:
  265. attn_weights = attn_weights * head_mask
  266. attn_output = torch.matmul(attn_weights, value)
  267. return attn_output, attn_weights
  268. def _split_heads(self, tensor, num_heads, attn_head_size):
  269. """
  270. Splits hidden_size dim into attn_head_size and num_heads
  271. """
  272. new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
  273. tensor = tensor.view(*new_shape)
  274. return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
  275. def _merge_heads(self, tensor, num_heads, attn_head_size):
  276. """
  277. Merges attn_head_size dim and num_attn_heads dim into hidden_size
  278. """
  279. tensor = tensor.permute(0, 2, 1, 3).contiguous()
  280. new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
  281. return tensor.view(new_shape)
  282. def forward(
  283. self,
  284. hidden_states: torch.Tensor,
  285. layer_past: Optional[bool] = None,
  286. attention_mask: Optional[torch.Tensor] = None,
  287. head_mask: Optional[torch.Tensor] = None,
  288. encoder_hidden_states: Optional[torch.Tensor] = None,
  289. encoder_attention_mask: Optional[torch.Tensor] = None,
  290. use_cache: Optional[bool] = False,
  291. output_attentions: Optional[bool] = False,
  292. ) -> tuple:
  293. if encoder_hidden_states is not None:
  294. if not hasattr(self, "q_attn"):
  295. raise ValueError(
  296. "If class is used as cross attention, the weights `q_attn` have to be defined. "
  297. "Please make sure to instantiate class with `ImageGPTAttention(..., is_cross_attention=True)`."
  298. )
  299. query = self.q_attn(hidden_states)
  300. key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
  301. attention_mask = encoder_attention_mask
  302. else:
  303. query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
  304. query = self._split_heads(query, self.num_heads, self.head_dim)
  305. key = self._split_heads(key, self.num_heads, self.head_dim)
  306. value = self._split_heads(value, self.num_heads, self.head_dim)
  307. if layer_past is not None:
  308. past_key, past_value = layer_past
  309. key = torch.cat((past_key, key), dim=-2)
  310. value = torch.cat((past_value, value), dim=-2)
  311. if use_cache is True:
  312. present = (key, value)
  313. else:
  314. present = None
  315. if self.reorder_and_upcast_attn:
  316. attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
  317. else:
  318. attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
  319. attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
  320. attn_output = self.c_proj(attn_output)
  321. attn_output = self.resid_dropout(attn_output)
  322. outputs = (attn_output, present)
  323. if output_attentions:
  324. outputs += (attn_weights,)
  325. return outputs # a, present, (attentions)
  326. class ImageGPTMLP(nn.Module):
  327. def __init__(self, intermediate_size, config):
  328. super().__init__()
  329. embed_dim = config.hidden_size
  330. self.c_fc = Conv1D(intermediate_size, embed_dim)
  331. self.c_proj = Conv1D(embed_dim, intermediate_size)
  332. self.act = ACT2FN[config.activation_function]
  333. self.dropout = nn.Dropout(config.resid_pdrop)
  334. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  335. hidden_states = self.c_fc(hidden_states)
  336. hidden_states = self.act(hidden_states)
  337. hidden_states = self.c_proj(hidden_states)
  338. hidden_states = self.dropout(hidden_states)
  339. return hidden_states
  340. class ImageGPTBlock(nn.Module):
  341. def __init__(self, config, layer_idx=None):
  342. super().__init__()
  343. hidden_size = config.hidden_size
  344. inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
  345. self.ln_1 = ImageGPTLayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  346. self.attn = ImageGPTAttention(config, layer_idx=layer_idx)
  347. self.ln_2 = ImageGPTLayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  348. if config.add_cross_attention:
  349. self.crossattention = ImageGPTAttention(config, is_cross_attention=True, layer_idx=layer_idx)
  350. self.ln_cross_attn = ImageGPTLayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  351. self.mlp = ImageGPTMLP(inner_dim, config)
  352. def forward(
  353. self,
  354. hidden_states: torch.Tensor,
  355. layer_past: Optional[bool] = None,
  356. attention_mask: Optional[torch.Tensor] = None,
  357. head_mask: Optional[torch.Tensor] = None,
  358. encoder_hidden_states: Optional[torch.Tensor] = None,
  359. encoder_attention_mask: Optional[torch.Tensor] = None,
  360. use_cache: Optional[bool] = False,
  361. output_attentions: Optional[bool] = False,
  362. ) -> tuple:
  363. residual = hidden_states
  364. hidden_states = self.ln_1(hidden_states)
  365. attn_outputs = self.attn(
  366. hidden_states,
  367. layer_past=layer_past,
  368. attention_mask=attention_mask,
  369. head_mask=head_mask,
  370. use_cache=use_cache,
  371. output_attentions=output_attentions,
  372. )
  373. attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
  374. outputs = attn_outputs[1:]
  375. # residual connection
  376. hidden_states = attn_output + residual
  377. if encoder_hidden_states is not None:
  378. # add one self-attention block for cross-attention
  379. if not hasattr(self, "crossattention"):
  380. raise ValueError(
  381. f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
  382. "cross-attention layers by setting `config.add_cross_attention=True`"
  383. )
  384. residual = hidden_states
  385. hidden_states = self.ln_cross_attn(hidden_states)
  386. cross_attn_outputs = self.crossattention(
  387. hidden_states,
  388. attention_mask=attention_mask,
  389. head_mask=head_mask,
  390. encoder_hidden_states=encoder_hidden_states,
  391. encoder_attention_mask=encoder_attention_mask,
  392. output_attentions=output_attentions,
  393. )
  394. attn_output = cross_attn_outputs[0]
  395. # residual connection
  396. hidden_states = residual + attn_output
  397. outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
  398. residual = hidden_states
  399. hidden_states = self.ln_2(hidden_states)
  400. feed_forward_hidden_states = self.mlp(hidden_states)
  401. # residual connection
  402. hidden_states = residual + feed_forward_hidden_states
  403. outputs = (hidden_states,) + (outputs if use_cache else outputs[1:])
  404. return outputs # hidden_states, present, (attentions, cross_attentions)
  405. class ImageGPTPreTrainedModel(PreTrainedModel):
  406. """
  407. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  408. models.
  409. """
  410. config_class = ImageGPTConfig
  411. load_tf_weights = load_tf_weights_in_imagegpt
  412. base_model_prefix = "transformer"
  413. main_input_name = "input_ids"
  414. supports_gradient_checkpointing = True
  415. _no_split_modules = ["ImageGPTBlock"]
  416. def __init__(self, *inputs, **kwargs):
  417. super().__init__(*inputs, **kwargs)
  418. def _init_weights(self, module):
  419. """Initialize the weights."""
  420. if isinstance(module, (nn.Linear, Conv1D)):
  421. # Slightly different from the TF version which uses truncated_normal for initialization
  422. # cf https://github.com/pytorch/pytorch/pull/5617
  423. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  424. if module.bias is not None:
  425. module.bias.data.zero_()
  426. elif isinstance(module, nn.Embedding):
  427. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  428. if module.padding_idx is not None:
  429. module.weight.data[module.padding_idx].zero_()
  430. elif isinstance(module, ImageGPTLayerNorm):
  431. module.weight.data.fill_(1.0)
  432. # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
  433. # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
  434. # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
  435. # > -- GPT-2 :: https://openai.com/blog/better-language-models/
  436. #
  437. # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
  438. for name, p in module.named_parameters():
  439. if "c_proj" in name and "weight" in name:
  440. # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
  441. p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
# Class-level documentation fragment; passed to `add_start_docstrings` on `ImageGPTModel`. It is a raw string
# because its text becomes part of generated docstrings.
IMAGEGPT_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`ImageGPTConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
# Shared documentation of the forward() arguments; injected into `ImageGPTModel.forward` through
# `add_start_docstrings_to_model_forward`.
IMAGEGPT_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
`input_ids_length` = `sequence_length` if `past_key_values` is `None` else
`past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
sequence tokens in the vocabulary.
If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
`input_ids`.
Indices can be obtained using [`AutoImageProcessor`]. See [`ImageGPTImageProcessor.__call__`] for details.
past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
`past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
their past given to this model should not be passed as `input_ids` as they have already been computed.
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
1]`:
- 0 corresponds to a *sentence A* token,
- 1 corresponds to a *sentence B* token.
[What are token type IDs?](../glossary#token-type-ids)
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
`past_key_values`).
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
  504. @add_start_docstrings(
  505. "The bare ImageGPT Model transformer outputting raw hidden-states without any specific head on top.",
  506. IMAGEGPT_START_DOCSTRING,
  507. )
  508. class ImageGPTModel(ImageGPTPreTrainedModel):
  509. def __init__(self, config: ImageGPTConfig):
  510. super().__init__(config)
  511. self.embed_dim = config.hidden_size
  512. self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
  513. self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
  514. self.drop = nn.Dropout(config.embd_pdrop)
  515. self.h = nn.ModuleList([ImageGPTBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])
  516. self.ln_f = ImageGPTLayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
  517. # Model parallel
  518. self.model_parallel = False
  519. self.device_map = None
  520. self.gradient_checkpointing = False
  521. # Initialize weights and apply final processing
  522. self.post_init()
  523. def get_input_embeddings(self):
  524. return self.wte
  525. def set_input_embeddings(self, new_embeddings):
  526. self.wte = new_embeddings
  527. def _prune_heads(self, heads_to_prune):
  528. """
  529. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
  530. """
  531. for layer, heads in heads_to_prune.items():
  532. self.h[layer].attn.prune_heads(heads)
  @add_start_docstrings_to_model_forward(IMAGEGPT_INPUTS_DOCSTRING)
  @replace_return_docstrings(output_type=BaseModelOutputWithPastAndCrossAttentions, config_class=_CONFIG_FOR_DOC)
  def forward(
      self,
      input_ids: Optional[torch.Tensor] = None,
      past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
      attention_mask: Optional[torch.Tensor] = None,
      token_type_ids: Optional[torch.Tensor] = None,
      position_ids: Optional[torch.Tensor] = None,
      head_mask: Optional[torch.Tensor] = None,
      inputs_embeds: Optional[torch.Tensor] = None,
      encoder_hidden_states: Optional[torch.Tensor] = None,
      encoder_attention_mask: Optional[torch.Tensor] = None,
      use_cache: Optional[bool] = None,
      output_attentions: Optional[bool] = None,
      output_hidden_states: Optional[bool] = None,
      return_dict: Optional[bool] = None,
      **kwargs: Any,
  ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
      r"""
      kwargs (`Dict[str, Any]`, *optional*):
          Only the deprecated `pixel_values` key is read. It is used in place of `input_ids` (a `FutureWarning` is
          emitted), and passing it together with `input_ids` raises a `ValueError`. Any other extra keyword argument
          is ignored by this method.

      Returns:

      Examples:

      ```python
      >>> from transformers import AutoImageProcessor, ImageGPTModel
      >>> from PIL import Image
      >>> import requests

      >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
      >>> image = Image.open(requests.get(url, stream=True).raw)

      >>> image_processor = AutoImageProcessor.from_pretrained("openai/imagegpt-small")
      >>> model = ImageGPTModel.from_pretrained("openai/imagegpt-small")

      >>> inputs = image_processor(images=image, return_tensors="pt")
      >>> outputs = model(**inputs)
      >>> last_hidden_states = outputs.last_hidden_state
      ```"""
      if "pixel_values" in kwargs:
          # Backward compatibility: `pixel_values` is the deprecated name of `input_ids`.
          warnings.warn(
              "The `pixel_values` argument is deprecated and will be removed in v4.47, use `input_ids` instead.",
              FutureWarning,
          )
          if input_ids is not None:
              raise ValueError(
                  "You cannot pass both `pixel_values` and `input_ids`. Please make sure to only pass `input_ids`."
              )
          input_ids = kwargs.pop("pixel_values")

      # Per-call flags fall back to the config defaults when not given explicitly.
      output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
      output_hidden_states = (
          output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
      )
      use_cache = use_cache if use_cache is not None else self.config.use_cache
      return_dict = return_dict if return_dict is not None else self.config.use_return_dict

      # Exactly one of `input_ids` / `inputs_embeds` must be provided; `input_shape` drives the output shape.
      if input_ids is not None and inputs_embeds is not None:
          raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
      elif input_ids is not None:
          self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
          input_shape = input_ids.size()
          # Collapse any leading dimensions into a single batch dimension.
          input_ids = input_ids.view(-1, input_shape[-1])
          batch_size = input_ids.shape[0]
      elif inputs_embeds is not None:
          input_shape = inputs_embeds.size()[:-1]
          batch_size = inputs_embeds.shape[0]
      else:
          raise ValueError("You have to specify either input_ids or inputs_embeds")

      device = input_ids.device if input_ids is not None else inputs_embeds.device

      if token_type_ids is not None:
          token_type_ids = token_type_ids.view(-1, input_shape[-1])

      # Without a cache every layer starts empty; with one, new positions continue after the cached length.
      if past_key_values is None:
          past_length = 0
          past_key_values = tuple([None] * len(self.h))
      else:
          past_length = past_key_values[0][0].size(-2)
      if position_ids is None:
          position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
          position_ids = position_ids.unsqueeze(0)

      # ImageGPTAttention mask.
      if attention_mask is not None:
          if batch_size <= 0:
              raise ValueError("batch_size has to be defined and > 0")
          attention_mask = attention_mask.view(batch_size, -1)
          # We create a 4D broadcastable attention mask from a 2D tensor mask.
          # Sizes are [batch_size, 1, 1, to_seq_length]
          # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
          # this attention mask is more simple than the triangular masking of causal attention
          # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
          attention_mask = attention_mask[:, None, None, :]

          # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
          # masked positions, this operation will create a tensor which is 0.0 for
          # positions we want to attend and the dtype's smallest value for masked positions.
          # Since we are adding it to the raw scores before the softmax, this is
          # effectively the same as removing these entirely.
          attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
          attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

      # If a 2D or 3D attention mask is provided for the cross-attention
      # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
      if self.config.add_cross_attention and encoder_hidden_states is not None:
          encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
          encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
          if encoder_attention_mask is None:
              encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
          encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
      else:
          encoder_attention_mask = None

      # Prepare head mask if needed
      # 1.0 in head_mask indicate we keep the head
      # attention_probs has shape bsz x n_heads x N x N
      # head_mask has shape n_layer x batch x n_heads x N x N
      head_mask = self.get_head_mask(head_mask, self.config.n_layer)

      # Token + position (+ optional token-type, which reuses the token table) embeddings.
      if inputs_embeds is None:
          inputs_embeds = self.wte(input_ids)
      position_embeds = self.wpe(position_ids)
      hidden_states = inputs_embeds + position_embeds

      if token_type_ids is not None:
          token_type_embeds = self.wte(token_type_ids)
          hidden_states = hidden_states + token_type_embeds

      hidden_states = self.drop(hidden_states)

      output_shape = input_shape + (hidden_states.size(-1),)

      if self.gradient_checkpointing and self.training:
          if use_cache:
              logger.warning_once(
                  "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
              )
              use_cache = False

      presents = () if use_cache else None
      all_self_attentions = () if output_attentions else None
      all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
      all_hidden_states = () if output_hidden_states else None
      for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
          # Model parallel
          if self.model_parallel:
              torch.cuda.set_device(hidden_states.device)
              # Ensure layer_past is on same device as hidden_states (might not be correct)
              if layer_past is not None:
                  layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
              # Ensure that attention_mask is always on the same device as hidden_states
              if attention_mask is not None:
                  attention_mask = attention_mask.to(hidden_states.device)
              if isinstance(head_mask, torch.Tensor):
                  head_mask = head_mask.to(hidden_states.device)
          if output_hidden_states:
              all_hidden_states = all_hidden_states + (hidden_states,)

          if self.gradient_checkpointing and self.training:
              # Positional arguments mirror the keyword call in the else-branch, with `None` in the
              # `layer_past` slot; `use_cache` is always False here (forced off before the loop).
              outputs = self._gradient_checkpointing_func(
                  block.__call__,
                  hidden_states,
                  None,
                  attention_mask,
                  head_mask[i],
                  encoder_hidden_states,
                  encoder_attention_mask,
                  use_cache,
                  output_attentions,
              )
          else:
              outputs = block(
                  hidden_states,
                  layer_past=layer_past,
                  attention_mask=attention_mask,
                  head_mask=head_mask[i],
                  encoder_hidden_states=encoder_hidden_states,
                  encoder_attention_mask=encoder_attention_mask,
                  use_cache=use_cache,
                  output_attentions=output_attentions,
              )

          # Block outputs are read as: hidden_states, then the present cache (only when use_cache),
          # then self-attentions, then cross-attentions.
          hidden_states = outputs[0]
          if use_cache is True:
              presents = presents + (outputs[1],)

          if output_attentions:
              all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
              if self.config.add_cross_attention:
                  all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)

          # Model Parallel: If it's the last layer for that device, put things on the next device
          # NOTE(review): `self.last_device` is not set in this class's __init__; presumably assigned by a
          # parallelize() helper elsewhere — confirm before relying on model_parallel.
          if self.model_parallel:
              for k, v in self.device_map.items():
                  if i == v[-1] and "cuda:" + str(k) != self.last_device:
                      hidden_states = hidden_states.to("cuda:" + str(k + 1))

      hidden_states = self.ln_f(hidden_states)

      # Restore any leading dimensions that were collapsed into the batch dimension.
      hidden_states = hidden_states.view(*output_shape)
      # Add last hidden state
      if output_hidden_states:
          all_hidden_states = all_hidden_states + (hidden_states,)

      if not return_dict:
          return tuple(
              v
              for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
              if v is not None
          )

      return BaseModelOutputWithPastAndCrossAttentions(
          last_hidden_state=hidden_states,
          past_key_values=presents,
          hidden_states=all_hidden_states,
          attentions=all_self_attentions,
          cross_attentions=all_cross_attentions,
      )
  @add_start_docstrings(
      """
      The ImageGPT Model transformer with a language modeling head on top (linear layer with weights tied to the input
      embeddings).
      """,
      IMAGEGPT_START_DOCSTRING,
  )
  class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel, GenerationMixin):
      _tied_weights_keys = ["lm_head.weight"]

      def __init__(self, config: ImageGPTConfig):
          super().__init__(config)
          self.transformer = ImageGPTModel(config)
          # The head scores `vocab_size - 1` classes: id `vocab_size - 1` is the start-of-sequence token
          # (see the generation example in `forward`), which is only fed as context and never predicted.
          self.lm_head = nn.Linear(config.n_embd, config.vocab_size - 1, bias=False)

          # Model parallel
          self.model_parallel = False
          self.device_map = None
          # Initialize weights and apply final processing
          self.post_init()

      def get_output_embeddings(self):
          """Return the language-modeling head (`self.lm_head`)."""
          return self.lm_head

      def set_output_embeddings(self, new_embeddings):
          """Replace the language-modeling head with `new_embeddings`."""
          self.lm_head = new_embeddings

      @add_start_docstrings_to_model_forward(IMAGEGPT_INPUTS_DOCSTRING)
      @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
      def forward(
          self,
          input_ids: Optional[torch.Tensor] = None,
          past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
          attention_mask: Optional[torch.Tensor] = None,
          token_type_ids: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.Tensor] = None,
          head_mask: Optional[torch.Tensor] = None,
          inputs_embeds: Optional[torch.Tensor] = None,
          encoder_hidden_states: Optional[torch.Tensor] = None,
          encoder_attention_mask: Optional[torch.Tensor] = None,
          labels: Optional[torch.Tensor] = None,
          use_cache: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
          **kwargs: Any,
      ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
              `labels = input_ids`. Indices should be in `[-100, 0, ..., config.vocab_size - 2]`: the head only scores
              `config.vocab_size - 1` tokens because the start-of-sequence token (`config.vocab_size - 1`) is never
              predicted. All labels set to `-100` are ignored (masked), the loss is only computed for labels in
              `[0, ..., config.vocab_size - 2]`.
          kwargs (`Dict[str, Any]`, *optional*):
              Only the deprecated `pixel_values` key is read; it is used in place of `input_ids` (a `FutureWarning`
              is emitted) and passing it together with `input_ids` raises a `ValueError`.

          Returns:

          Examples:

          ```python
          >>> from transformers import AutoImageProcessor, ImageGPTForCausalImageModeling
          >>> import torch
          >>> import matplotlib.pyplot as plt
          >>> import numpy as np

          >>> image_processor = AutoImageProcessor.from_pretrained("openai/imagegpt-small")
          >>> model = ImageGPTForCausalImageModeling.from_pretrained("openai/imagegpt-small")
          >>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
          >>> model.to(device)  # doctest: +IGNORE_RESULT

          >>> # unconditional generation of 4 images
          >>> batch_size = 4
          >>> context = torch.full((batch_size, 1), model.config.vocab_size - 1)  # initialize with SOS token
          >>> context = context.to(device)
          >>> output = model.generate(
          ...     input_ids=context, max_length=model.config.n_positions + 1, temperature=1.0, do_sample=True, top_k=40
          ... )

          >>> clusters = image_processor.clusters
          >>> height = image_processor.size["height"]
          >>> width = image_processor.size["width"]

          >>> samples = output[:, 1:].cpu().detach().numpy()
          >>> samples_img = [
          ...     np.reshape(np.rint(127.5 * (clusters[s] + 1.0)), [height, width, 3]).astype(np.uint8) for s in samples
          ... ]  # convert color cluster tokens back to pixels
          >>> f, axes = plt.subplots(1, batch_size, dpi=300)

          >>> for img, ax in zip(samples_img, axes):  # doctest: +IGNORE_RESULT
          ...     ax.axis("off")
          ...     ax.imshow(img)
          ```"""
          if "pixel_values" in kwargs:
              # Backward compatibility: `pixel_values` is the deprecated name of `input_ids`.
              warnings.warn(
                  "The `pixel_values` argument is deprecated and will be removed in v4.47, use `input_ids` instead.",
                  FutureWarning,
              )

              if input_ids is not None:
                  raise ValueError(
                      "You cannot pass both `pixel_values` and `input_ids`. Please make sure to only pass `input_ids`."
                  )

              input_ids = kwargs.pop("pixel_values")

          return_dict = return_dict if return_dict is not None else self.config.use_return_dict

          transformer_outputs = self.transformer(
              input_ids,
              past_key_values=past_key_values,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              head_mask=head_mask,
              inputs_embeds=inputs_embeds,
              encoder_hidden_states=encoder_hidden_states,
              encoder_attention_mask=encoder_attention_mask,
              use_cache=use_cache,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )
          hidden_states = transformer_outputs[0]

          lm_logits = self.lm_head(hidden_states)

          loss = None
          if labels is not None:
              # The model may be split across devices (model parallel / device_map), so the logits can end up on a
              # different device than the user's labels; align them before computing the loss.
              labels = labels.to(lm_logits.device)
              # Shift so that tokens < n predict n
              shift_logits = lm_logits[..., :-1, :].contiguous()
              shift_labels = labels[..., 1:].contiguous()
              # Flatten the tokens
              loss_fct = CrossEntropyLoss()
              loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

          if not return_dict:
              output = (lm_logits,) + transformer_outputs[1:]
              return ((loss,) + output) if loss is not None else output

          return CausalLMOutputWithCrossAttentions(
              loss=loss,
              logits=lm_logits,
              past_key_values=transformer_outputs.past_key_values,
              hidden_states=transformer_outputs.hidden_states,
              attentions=transformer_outputs.attentions,
              cross_attentions=transformer_outputs.cross_attentions,
          )

      @staticmethod
      def _reorder_cache(
          past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
      ) -> Tuple[Tuple[torch.Tensor]]:
          """
          This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
          [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
          beam_idx at every generation step.
          """
          # Select the surviving beams along dim 0 of every cached tensor, keeping each tensor on its own device.
          return tuple(
              tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
              for layer_past in past_key_values
          )
  @add_start_docstrings(
      """
      The ImageGPT Model transformer with an image classification head on top (linear layer).
      [`ImageGPTForImageClassification`] average-pools the hidden states in order to do the classification.
      """,
      IMAGEGPT_START_DOCSTRING,
  )
  class ImageGPTForImageClassification(ImageGPTPreTrainedModel):
      def __init__(self, config: ImageGPTConfig):
          super().__init__(config)
          self.num_labels = config.num_labels
          self.transformer = ImageGPTModel(config)
          self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)

          # Initialize weights and apply final processing
          self.post_init()

      @add_start_docstrings_to_model_forward(IMAGEGPT_INPUTS_DOCSTRING)
      @replace_return_docstrings(output_type=SequenceClassifierOutputWithPast, config_class=_CONFIG_FOR_DOC)
      def forward(
          self,
          input_ids: Optional[torch.Tensor] = None,
          past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
          attention_mask: Optional[torch.Tensor] = None,
          token_type_ids: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.Tensor] = None,
          head_mask: Optional[torch.Tensor] = None,
          inputs_embeds: Optional[torch.Tensor] = None,
          labels: Optional[torch.Tensor] = None,
          use_cache: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
          **kwargs: Any,
      ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
              config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
              `config.num_labels > 1` a classification loss is computed (Cross-Entropy). When
              `config.problem_type` is unset, it is inferred from `num_labels` and the label dtype on the first call
              with labels and stored on the config.
          kwargs (`Dict[str, Any]`, *optional*):
              Only the deprecated `pixel_values` key is read; it is used in place of `input_ids` (a `FutureWarning`
              is emitted) and passing it together with `input_ids` raises a `ValueError`.

          Note: the classification features are the plain mean of the final hidden states over the sequence
          dimension; `attention_mask` is not used to exclude positions from this average.

          Returns:

          Examples:

          ```python
          >>> from transformers import AutoImageProcessor, ImageGPTForImageClassification
          >>> from PIL import Image
          >>> import requests

          >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
          >>> image = Image.open(requests.get(url, stream=True).raw)

          >>> image_processor = AutoImageProcessor.from_pretrained("openai/imagegpt-small")
          >>> model = ImageGPTForImageClassification.from_pretrained("openai/imagegpt-small")

          >>> inputs = image_processor(images=image, return_tensors="pt")
          >>> outputs = model(**inputs)
          >>> logits = outputs.logits
          ```"""
          if "pixel_values" in kwargs:
              # Backward compatibility: `pixel_values` is the deprecated name of `input_ids`.
              warnings.warn(
                  "The `pixel_values` argument is deprecated and will be removed in v4.47, use `input_ids` instead.",
                  FutureWarning,
              )

              if input_ids is not None:
                  raise ValueError(
                      "You cannot pass both `pixel_values` and `input_ids`. Please make sure to only pass `input_ids`."
                  )

              input_ids = kwargs.pop("pixel_values")

          return_dict = return_dict if return_dict is not None else self.config.use_return_dict

          transformer_outputs = self.transformer(
              input_ids,
              past_key_values=past_key_values,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              head_mask=head_mask,
              inputs_embeds=inputs_embeds,
              use_cache=use_cache,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )
          hidden_states = transformer_outputs[0]
          # average-pool the hidden states along the sequence dimension
          pooled_hidden_states = hidden_states.mean(dim=1)
          # project from (batch_size, hidden_size) to (batch_size, num_labels)
          logits = self.score(pooled_hidden_states)

          loss = None
          if labels is not None:
              # The model may be split across devices, so align labels with the logits before computing the loss.
              labels = labels.to(logits.device)
              # Infer the problem type once and persist it on the config for later calls.
              if self.config.problem_type is None:
                  if self.num_labels == 1:
                      self.config.problem_type = "regression"
                  elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                      self.config.problem_type = "single_label_classification"
                  else:
                      self.config.problem_type = "multi_label_classification"

              if self.config.problem_type == "regression":
                  loss_fct = MSELoss()
                  if self.num_labels == 1:
                      loss = loss_fct(logits.squeeze(), labels.squeeze())
                  else:
                      loss = loss_fct(logits, labels)
              elif self.config.problem_type == "single_label_classification":
                  loss_fct = CrossEntropyLoss()
                  loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
              elif self.config.problem_type == "multi_label_classification":
                  loss_fct = BCEWithLogitsLoss()
                  loss = loss_fct(logits, labels)

          if not return_dict:
              output = (logits,) + transformer_outputs[1:]
              return ((loss,) + output) if loss is not None else output

          return SequenceClassifierOutputWithPast(
              loss=loss,
              logits=logits,
              past_key_values=transformer_outputs.past_key_values,
              hidden_states=transformer_outputs.hidden_states,
              attentions=transformer_outputs.attentions,
          )