modeling_patchtst.py

# coding=utf-8
# Copyright 2023 IBM & Hugging Face. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch PatchTST model."""

import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
from torch import nn

from ...activations import ACT2CLS
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
from ...utils import ModelOutput, add_start_docstrings, logging
from .configuration_patchtst import PatchTSTConfig


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "PatchTSTConfig"


# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->PatchTST
class PatchTSTAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[PatchTSTConfig] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.reshape(*proj_shape)
        value_states = value_states.reshape(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value


class PatchTSTBatchNorm(nn.Module):
    """
    Compute batch normalization over the sequence length (time) dimension.
    """

    def __init__(self, config: PatchTSTConfig):
        super().__init__()
        self.batchnorm = nn.BatchNorm1d(config.d_model, eps=config.norm_eps)

    def forward(self, inputs: torch.Tensor):
        """
        Parameters:
            inputs (`torch.Tensor` of shape `(batch_size, sequence_length, d_model)`):
                input for Batch norm calculation
        Returns:
            `torch.Tensor` of shape `(batch_size, sequence_length, d_model)`
        """
        output = inputs.transpose(1, 2)  # output: (batch_size, d_model, sequence_length)
        output = self.batchnorm(output)
        return output.transpose(1, 2)
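
# Illustrative sketch (hedged; the shapes below are assumed example values, not part of the
# model API): `nn.BatchNorm1d(d_model)` expects input shaped `(batch, d_model, length)` and
# normalizes each of the `d_model` features over the batch and time axes, hence the two
# transposes above. For a `PatchTSTConfig` with `d_model=8`:
#
#     >>> norm = PatchTSTBatchNorm(config)            # config.d_model == 8 assumed
#     >>> norm(torch.randn(4, 20, 8)).shape           # (bs, seq_len, d_model) is preserved
#     torch.Size([4, 20, 8])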


def random_masking(
    inputs: torch.Tensor,
    mask_ratio: float,
    unmasked_channel_indices: list = None,
    channel_consistent_masking: bool = False,
    mask_value: int = 0,
):
    """random_masking: Mask the input considering the control variables.

    Args:
        inputs (`torch.Tensor` of shape `(batch_size, num_channels, sequence_length, num_features)`):
            The input tensor to mask.
        mask_ratio (`float`):
            Masking ratio applied to mask the input data during random pretraining. It is a number between 0 and 1.
        unmasked_channel_indices (list, *optional*):
            Indices of channels that will not be masked.
        channel_consistent_masking (bool, *optional*, defaults to `False`):
            When true, masking will be the same across all channels of a timeseries. Otherwise, masking positions will
            vary across channels.
        mask_value (int, *optional*, defaults to 0):
            Define the value of masked patches for pretraining.

    Returns:
        `tuple(torch.Tensor)`: `inputs_mask`, the masked input with the same shape as `inputs`, and a mask tensor of
        shape `[bs x num_channels x n]`.
    """
    if mask_ratio < 0 or mask_ratio >= 1:
        raise ValueError(f"Mask ratio {mask_ratio} has to be between 0 and 1.")

    batch_size, num_channels, sequence_length, num_features = inputs.shape
    device = inputs.device

    len_keep = int(sequence_length * (1 - mask_ratio))
    if channel_consistent_masking:
        noise = torch.rand(batch_size, 1, sequence_length, device=device)  # noise in [0, 1], bs x 1 x L
        noise = noise.repeat(1, num_channels, 1)  # bs x num_channels x time
    else:
        # noise in [0, 1], bs x num_channels x L
        noise = torch.rand(batch_size, num_channels, sequence_length, device=device)

    # mask: [bs x num_channels x num_patch]
    mask = torch.ones(batch_size, num_channels, sequence_length, device=device)
    mask[:, :, :len_keep] = 0

    # sort noise for each sample
    ids_shuffle = torch.argsort(noise, dim=-1)  # ascend: small is keep, large is remove
    ids_restore = torch.argsort(ids_shuffle, dim=-1)  # ids_restore: [bs x num_channels x L]

    mask = torch.gather(mask, dim=-1, index=ids_restore)
    mask = mask.unsqueeze(-1).repeat(1, 1, 1, num_features)  # mask: [bs x num_channels x num_patches x patch_length]
    if unmasked_channel_indices is not None:
        mask[:, unmasked_channel_indices, :, :] = 0

    inputs_mask = inputs.masked_fill(mask.bool(), mask_value)
    return inputs_mask, mask[..., 0]
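
# Illustrative sketch (hedged; example shapes assumed): with patched inputs of shape
# (bs=2, num_channels=3, num_patches=10, patch_length=16) and `mask_ratio=0.4`,
# `len_keep = int(10 * 0.6) = 6`, so 4 randomly chosen patch positions per channel are
# filled with `mask_value`, and the returned mask marks those positions with 1.0.
#
#     >>> inputs = torch.randn(2, 3, 10, 16)
#     >>> masked, mask = random_masking(inputs, mask_ratio=0.4)
#     >>> masked.shape, mask.shape
#     (torch.Size([2, 3, 10, 16]), torch.Size([2, 3, 10]))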


def forecast_masking(
    inputs: torch.Tensor,
    num_forecast_mask_patches: Union[list, int],
    unmasked_channel_indices: list = None,
    mask_value: int = 0,
):
    """Forecast masking that masks the last K patches, where K is taken from `num_forecast_mask_patches`.
    If `num_forecast_mask_patches` is a list, samples in the batch will be randomly masked by the numbers defined in
    the list.

    Parameters:
        inputs (`torch.Tensor`):
            Input of shape `(bs, num_channels, num_patch, patch_length)`
        num_forecast_mask_patches (`list` or `int`):
            Number of patches to be masked at the end of each batch sample. e.g. 4 or [3, 5].
        unmasked_channel_indices (`list`, *optional*):
            Indices of channels that are not masked.
        mask_value (`int`, *optional*, defaults to 0):
            Values in the masked patches will be filled by `mask_value`.

    Returns:
        `tuple(torch.Tensor)`: `inputs_mask`, the masked input with the same shape as `inputs`, and a mask tensor of
        shape `(bs, num_channels, num_patch)` or `(bs, tsg1, tsg2, num_channels, num_patch)`.
    """
    if isinstance(num_forecast_mask_patches, int):
        num_forecast_mask_patches = [num_forecast_mask_patches]
    forecast_mask_ratios = [1 for _ in num_forecast_mask_patches]

    batch_size, num_channels, sequence_length, num_features = inputs.shape
    mask = torch.zeros(batch_size, num_channels, sequence_length, device=inputs.device)

    t_list = []
    total_length = 0
    total_ratio = sum(forecast_mask_ratios)

    for patch_length, ratio in zip(num_forecast_mask_patches, forecast_mask_ratios):
        if patch_length <= 0 or patch_length >= sequence_length:
            raise ValueError(
                f"num_forecast_mask_patches {patch_length} should be greater than 0 and less than total patches."
            )
        temp_len = int(batch_size * ratio / total_ratio)
        t_list.append([patch_length, ratio, temp_len])
        total_length += temp_len

    t_list = sorted(t_list, key=lambda x: x[2])

    if total_length < batch_size:
        t_list[0][2] = t_list[0][2] + (batch_size - total_length)
    elif total_length > batch_size:
        t_list[-1][2] = t_list[-1][2] + (total_length - batch_size)

    batch1 = 0
    for patch_len, _, temp_len in t_list:
        batch2 = batch1 + temp_len
        mask[batch1:batch2, :, -patch_len:] = 1
        batch1 = batch2

    perm = torch.randperm(mask.shape[0])
    mask = mask[perm]

    mask = mask.unsqueeze(-1).repeat(1, 1, 1, num_features)  # mask: [bs x num_channels x num_patch x patch_len]
    if unmasked_channel_indices is not None:
        mask[:, unmasked_channel_indices, :, :] = 0

    inputs_mask = inputs.masked_fill(mask.bool(), mask_value)
    return inputs_mask, mask[..., 0]
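
# Illustrative sketch (hedged; example shapes assumed): when `num_forecast_mask_patches` is a
# list, the batch is split roughly evenly between the listed horizons (all ratios are 1) and
# then shuffled. For a batch of 6 samples:
#
#     >>> inputs = torch.randn(6, 2, 10, 16)
#     >>> masked, mask = forecast_masking(inputs, num_forecast_mask_patches=[2, 4])
#     >>> mask.shape                              # (bs, num_channels, num_patches)
#     torch.Size([6, 2, 10])
#     >>> # 3 samples have their last 2 patches masked, the other 3 their last 4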


class PatchTSTPatchify(nn.Module):
    """
    A class to patchify the time series sequence into different patches

    Returns:
        `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`
    """

    def __init__(self, config: PatchTSTConfig):
        super().__init__()

        self.sequence_length = config.context_length
        self.patch_length = config.patch_length
        self.patch_stride = config.patch_stride

        if self.sequence_length <= self.patch_length:
            raise ValueError(
                f"Sequence length ({self.sequence_length}) has to be greater than the patch length ({self.patch_length})"
            )

        # get the number of patches
        self.num_patches = (max(self.sequence_length, self.patch_length) - self.patch_length) // self.patch_stride + 1
        new_sequence_length = self.patch_length + self.patch_stride * (self.num_patches - 1)
        self.sequence_start = self.sequence_length - new_sequence_length

    def forward(self, past_values: torch.Tensor):
        """
        Parameters:
            past_values (`torch.Tensor` of shape `(batch_size, sequence_length, num_channels)`, *required*):
                Input for patchification

        Returns:
            `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`
        """
        sequence_length = past_values.shape[-2]
        if sequence_length != self.sequence_length:
            raise ValueError(
                f"Input sequence length ({sequence_length}) doesn't match model configuration ({self.sequence_length})."
            )
        # output: [bs x new_sequence_length x num_channels]
        output = past_values[:, self.sequence_start :, :]
        # output: [bs x num_patches x num_input_channels x patch_length]
        output = output.unfold(dimension=-2, size=self.patch_length, step=self.patch_stride)
        # output: [bs x num_input_channels x num_patches x patch_length]
        output = output.transpose(-2, -3).contiguous()
        return output
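
# Illustrative sketch (hedged; the config values are assumed for the example): a worked pass
# through the patch arithmetic above with `context_length=512`, `patch_length=12`,
# `patch_stride=12`:
#     num_patches         = (512 - 12) // 12 + 1 = 42
#     new_sequence_length = 12 + 12 * (42 - 1)   = 504
#     sequence_start      = 512 - 504            = 8
# so the first 8 time steps are dropped and an input of shape (bs, 512, num_channels)
# becomes patches of shape (bs, num_channels, 42, 12).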


class PatchTSTMasking(nn.Module):
    """
    Class to perform random or forecast masking.

    Parameters:
        config (`PatchTSTConfig`): model config
    Returns:
        x_mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`)
            Masked patched input
        mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`)
            Bool tensor indicating True on masked points
    """

    def __init__(self, config: PatchTSTConfig):
        super().__init__()
        self.random_mask_ratio = config.random_mask_ratio
        self.channel_consistent_masking = config.channel_consistent_masking
        self.mask_type = config.mask_type
        self.num_forecast_mask_patches = config.num_forecast_mask_patches
        self.unmasked_channel_indices = config.unmasked_channel_indices
        self.mask_value = config.mask_value
        if self.unmasked_channel_indices is not None:
            self.unmasked_channel_indices = sorted(self.unmasked_channel_indices)

    def forward(self, patch_input: torch.Tensor):
        """
        Parameters:
            patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*):
                Patch input

        Return:
            masked_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`)
                Masked patched input
            mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`)
                Bool tensor indicating True on masked points
        """
        if self.mask_type == "random":
            masked_input, mask = random_masking(
                inputs=patch_input,
                mask_ratio=self.random_mask_ratio,
                unmasked_channel_indices=self.unmasked_channel_indices,
                channel_consistent_masking=self.channel_consistent_masking,
                mask_value=self.mask_value,
            )
        elif self.mask_type == "forecast":
            masked_input, mask = forecast_masking(
                inputs=patch_input,
                num_forecast_mask_patches=self.num_forecast_mask_patches,
                unmasked_channel_indices=self.unmasked_channel_indices,
                mask_value=self.mask_value,
            )
        else:
            raise ValueError(f"Invalid mask type {self.mask_type}.")

        # mask: [bs x num_input_channels x num_patch]
        mask = mask.bool()
        return masked_input, mask


class PatchTSTEncoderLayer(nn.Module):
    """
    PatchTST encoder layer
    """

    def __init__(self, config: PatchTSTConfig):
        super().__init__()

        self.channel_attention = config.channel_attention
        # Multi-Head attention
        self.self_attn = PatchTSTAttention(
            embed_dim=config.d_model,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
        )

        # Add & Norm of the sublayer 1
        self.dropout_path1 = nn.Dropout(config.path_dropout) if config.path_dropout > 0 else nn.Identity()
        if config.norm_type == "batchnorm":
            self.norm_sublayer1 = PatchTSTBatchNorm(config)
        elif config.norm_type == "layernorm":
            self.norm_sublayer1 = nn.LayerNorm(config.d_model, eps=config.norm_eps)
        else:
            raise ValueError(f"{config.norm_type} is not a supported norm layer type.")

        # Add & Norm of the sublayer 2
        if self.channel_attention:
            self.dropout_path2 = nn.Dropout(config.path_dropout) if config.path_dropout > 0 else nn.Identity()
            if config.norm_type == "batchnorm":
                self.norm_sublayer2 = PatchTSTBatchNorm(config)
            elif config.norm_type == "layernorm":
                self.norm_sublayer2 = nn.LayerNorm(config.d_model, eps=config.norm_eps)
            else:
                raise ValueError(f"{config.norm_type} is not a supported norm layer type.")

        # Position-wise Feed-Forward
        self.ff = nn.Sequential(
            nn.Linear(config.d_model, config.ffn_dim, bias=config.bias),
            ACT2CLS[config.activation_function](),
            nn.Dropout(config.ff_dropout) if config.ff_dropout > 0 else nn.Identity(),
            nn.Linear(config.ffn_dim, config.d_model, bias=config.bias),
        )

        # Add & Norm of sublayer 3
        self.dropout_path3 = nn.Dropout(config.path_dropout) if config.path_dropout > 0 else nn.Identity()
        if config.norm_type == "batchnorm":
            self.norm_sublayer3 = PatchTSTBatchNorm(config)
        elif config.norm_type == "layernorm":
            self.norm_sublayer3 = nn.LayerNorm(config.d_model, eps=config.norm_eps)
        else:
            raise ValueError(f"{config.norm_type} is not a supported norm layer type.")

        self.pre_norm = config.pre_norm

    def forward(self, hidden_state: torch.Tensor, output_attentions: Optional[bool] = None):
        """
        Parameters:
            hidden_state (`torch.Tensor` of shape `(batch_size, num_channels, sequence_length, d_model)`, *required*):
                Past values of the time series
            output_attentions (`bool`, *optional*):
                Whether or not to return the output attention of all layers
        Return:
            `torch.Tensor` of shape `(batch_size, num_channels, sequence_length, d_model)`
        """
        batch_size, num_input_channels, sequence_length, d_model = hidden_state.shape

        # First sublayer: attention across time
        # hidden_states: [(bs*num_channels) x sequence_length x d_model]
        hidden_state = hidden_state.view(batch_size * num_input_channels, sequence_length, d_model)

        if self.pre_norm:
            ## Norm and Multi-Head attention and Add residual connection
            attn_output, attn_weights, _ = self.self_attn(
                hidden_states=self.norm_sublayer1(hidden_state), output_attentions=output_attentions
            )
            # Add: residual connection with residual dropout
            hidden_state = hidden_state + self.dropout_path1(attn_output)
        else:
            ## Multi-Head attention and Add residual connection and Norm - Standard Transformer from BERT
            attn_output, attn_weights, _ = self.self_attn(
                hidden_states=hidden_state, output_attentions=output_attentions
            )
            # hidden_states: [(bs*num_channels) x sequence_length x d_model]
            hidden_state = self.norm_sublayer1(hidden_state + self.dropout_path1(attn_output))

        # hidden_state: [bs x num_channels x sequence_length x d_model]
        hidden_state = hidden_state.reshape(batch_size, num_input_channels, sequence_length, d_model)

        # second sublayer: attention across variable at any given time
        if self.channel_attention:
            # hidden_state: [bs x sequence_length x num_channels x d_model]
            hidden_state = hidden_state.transpose(2, 1).contiguous()
            # hidden_state: [(bs*sequence_length) x num_channels x d_model]
            hidden_state = hidden_state.view(batch_size * sequence_length, num_input_channels, d_model)
            if self.pre_norm:
                ## Norm and Multi-Head attention and Add residual connection
                attn_output, channel_attn_weights, _ = self.self_attn(
                    hidden_states=self.norm_sublayer2(hidden_state), output_attentions=output_attentions
                )
                # Add: residual connection with residual dropout
                hidden_state = hidden_state + self.dropout_path2(attn_output)
            else:
                ## Multi-Head attention and Add residual connection and Norm
                attn_output, channel_attn_weights, _ = self.self_attn(
                    hidden_states=hidden_state, output_attentions=output_attentions
                )
                # hidden_states: [(bs*sequence_length) x num_channels x d_model]
                hidden_state = self.norm_sublayer2(hidden_state + self.dropout_path2(attn_output))

            # Reshape hidden state
            # hidden_state: [bs x sequence_length x num_channels x d_model]
            hidden_state = hidden_state.reshape(batch_size, sequence_length, num_input_channels, d_model)
            # hidden_state: [bs x num_channels x sequence_length x d_model]
            hidden_state = hidden_state.transpose(1, 2).contiguous()

        # Third sublayer: mixing across hidden
        # hidden_state: [(batch_size*num_channels) x sequence_length x d_model]
        hidden_state = hidden_state.view(batch_size * num_input_channels, sequence_length, d_model)
        if self.pre_norm:
            ## Norm and Position-wise Feed-Forward and Add residual connection
            # Add: residual connection with residual dropout
            hidden_state = hidden_state + self.dropout_path3(self.ff(self.norm_sublayer3(hidden_state)))
        else:
            ## Position-wise Feed-Forward and Add residual connection and Norm
            # Add: residual connection with residual dropout
            hidden_state = self.norm_sublayer3(hidden_state + self.dropout_path3(self.ff(hidden_state)))

        # [bs x num_channels x sequence_length x d_model]
        hidden_state = hidden_state.reshape(batch_size, num_input_channels, sequence_length, d_model)

        outputs = (hidden_state,)
        if output_attentions:
            outputs += (attn_weights, channel_attn_weights) if self.channel_attention else (attn_weights,)

        return outputs
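
# Illustrative sketch (hedged; C and P are shorthand introduced only for this note): shape
# flow through one encoder layer when `channel_attention=True`, with C = num_input_channels
# and P = number of patches in the sequence dimension:
#     (bs, C, P, d_model)
#       -> view to (bs*C, P, d_model)    # self-attention across time (patches)
#       -> back to (bs, C, P, d_model)
#       -> view to (bs*P, C, d_model)    # self-attention across channels
#       -> back to (bs, C, P, d_model)
#       -> view to (bs*C, P, d_model)    # position-wise feed-forward
#       -> back to (bs, C, P, d_model)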


class PatchTSTPreTrainedModel(PreTrainedModel):
    config_class = PatchTSTConfig
    base_model_prefix = "model"
    main_input_name = "past_values"
    supports_gradient_checkpointing = False

    def _init_weights(self, module):
        """
        Initialize weights
        """
        if isinstance(module, PatchTSTPositionalEncoding):
            # initialize cls_token
            if self.config.use_cls_token:
                nn.init.normal_(module.cls_token, std=0.02)
            # initialize positional encoding
            if self.config.positional_encoding_type == "random":
                nn.init.normal_(module.position_enc, mean=0.0, std=0.1)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, PatchTSTBatchNorm):
            module.batchnorm.bias.data.zero_()
            module.batchnorm.weight.data.fill_(1.0)
        elif isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=self.config.init_std)
            if module.bias is not None:
                module.bias.data.zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (PatchTSTEncoder)):
            module.gradient_checkpointing = value


class PatchTSTEmbedding(nn.Module):
    def __init__(self, config: PatchTSTConfig):
        super().__init__()
        self.num_input_channels = config.num_input_channels
        self.share_embedding = config.share_embedding
        # Input encoding: projection of feature vectors onto a d-dim vector space
        if self.share_embedding:
            self.input_embedding = nn.Linear(config.patch_length, config.d_model)
        else:
            self.input_embedding = nn.ModuleList()
            for _ in range(config.num_input_channels):
                self.input_embedding.append(nn.Linear(config.patch_length, config.d_model))

    def forward(self, patch_input: torch.Tensor):
        """
        Parameters:
            patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*):
                Patch input for embedding
        return:
            `torch.Tensor` of shape `(batch_size, num_channels, num_patches, d_model)`
        """
        # Input encoding
        num_input_channels = patch_input.shape[1]
        if num_input_channels != self.num_input_channels:
            raise ValueError(
                f"The defined number of input channels ({self.num_input_channels}) in the config "
                f"has to be the same as the number of channels in the batch input ({num_input_channels})"
            )
        if self.share_embedding:
            embeddings = self.input_embedding(patch_input)  # x: [bs x num_channels x num_patches x d_model]
        else:
            embeddings = [self.input_embedding[i](patch_input[:, i, :, :]) for i in range(num_input_channels)]
            embeddings = torch.stack(embeddings, dim=1)
        return embeddings


class PatchTSTPositionalEncoding(nn.Module):
    """
    Class for positional encoding
    """

    def __init__(self, config: PatchTSTConfig, num_patches: int):
        super().__init__()
        self.use_cls_token = config.use_cls_token
        self.num_input_channels = config.num_input_channels
        if config.use_cls_token:
            # cls_token: [1 x num_input_channels x 1 x d_model]
            self.cls_token = nn.Parameter(torch.zeros(1, 1, 1, config.d_model))
            num_patches += 1
        # positional encoding: [num_patches x d_model]
        self.position_enc = self._init_pe(config, num_patches)
        # Positional dropout
        self.positional_dropout = (
            nn.Dropout(config.positional_dropout) if config.positional_dropout > 0 else nn.Identity()
        )

    @staticmethod
    def _init_pe(config: PatchTSTConfig, num_patches: int) -> nn.Parameter:
        # Positional encoding
        if config.positional_encoding_type == "random":
            position_enc = nn.Parameter(torch.randn(num_patches, config.d_model), requires_grad=True)
        elif config.positional_encoding_type == "sincos":
            position_enc = torch.zeros(num_patches, config.d_model)
            position = torch.arange(0, num_patches).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, config.d_model, 2) * -(math.log(10000.0) / config.d_model))
            position_enc[:, 0::2] = torch.sin(position * div_term)
            position_enc[:, 1::2] = torch.cos(position * div_term)
            position_enc = position_enc - position_enc.mean()
            position_enc = position_enc / (position_enc.std() * 10)
            position_enc = nn.Parameter(position_enc, requires_grad=False)
        else:
            raise ValueError(
                f"{config.positional_encoding_type} is not a valid positional encoder. Available types are 'random' and 'sincos'."
            )
        return position_enc

    def forward(self, patch_input: torch.Tensor):
        if self.use_cls_token:
            # patch_input: [bs x num_channels x num_patches x d_model]
            patch_input = self.positional_dropout(patch_input + self.position_enc[1:, :])
            # append cls token where cls_token: [1 x num_channels x 1 x d_model]
            cls_token = self.cls_token + self.position_enc[:1, :]
            # get the same copy of cls_token for all the samples in batch: [bs x num_channels x 1 x d_model]
            cls_tokens = cls_token.expand(patch_input.shape[0], self.num_input_channels, -1, -1)
            # hidden_state: [bs x num_channels x (num_patches+1) x d_model]
            hidden_state = torch.cat((cls_tokens, patch_input), dim=2)
        else:
            # hidden_state: [bs x num_channels x num_patches x d_model]
            hidden_state = self.positional_dropout(patch_input + self.position_enc)
        return hidden_state
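
# Illustrative sketch (hedged): the "sincos" table built in `_init_pe` follows the standard
# Transformer encoding, PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and
# PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)); it is then centered (mean subtracted) and
# divided by 10 * its standard deviation, so the frozen table stays small relative to the
# learned patch embeddings it is added to.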


class PatchTSTEncoder(PatchTSTPreTrainedModel):
    """
    PatchTST Encoder
    """

    def __init__(self, config: PatchTSTConfig, num_patches: int):
        super().__init__(config)
        self.gradient_checkpointing = False

        # Input embedding: projection of feature vectors onto a d-dim vector space
        self.embedder = PatchTSTEmbedding(config)
        # Positional encoding
        self.positional_encoder = PatchTSTPositionalEncoding(config, num_patches)
        # Encoder
        self.layers = nn.ModuleList([PatchTSTEncoderLayer(config) for i in range(config.num_hidden_layers)])

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        patch_input: torch.Tensor,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
    ) -> BaseModelOutput:
        """
        Parameters:
            patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*):
                Past values of the time series
            output_hidden_states (bool, optional): Indicates if hidden states should be outputted.
            output_attentions (bool, optional): Indicates if attentions should be outputted.
        return:
            `BaseModelOutput`
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # Input embedding
        patch_input = self.embedder(patch_input)
        # Positional encoding
        hidden_state = self.positional_encoder(patch_input)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        for encoder_layer in self.layers:
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_state,)

            layer_outputs = encoder_layer(hidden_state=hidden_state, output_attentions=output_attentions)
            # get hidden state. hidden_state shape is [bs x num_channels x num_patches x d_model]
            # or [bs x num_channels x (num_patches+1) x d_model] if use cls_token
            hidden_state = layer_outputs[0]
            # append attention matrix at each layer
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)
        # return past_values, hidden_states
        return BaseModelOutput(last_hidden_state=hidden_state, hidden_states=encoder_states, attentions=all_attentions)


PATCHTST_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`PatchTSTConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@dataclass
class PatchTSTModelOutput(ModelOutput):
    """
    Base class for model's outputs, with potential hidden states.

    Parameters:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, patch_length)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, num_channels, num_patches, d_model)`.
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        mask (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches)`, *optional*):
            Bool masked tensor indicating which patches are masked
        loc (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*):
            Mean of the input data (batch_size, sequence_length, num_channels) over the sequence_length
        scale (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*):
            Std of the input data (batch_size, sequence_length, num_channels) over the sequence_length
        patch_input (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, patch_length)`):
            Patched input to the Transformer
    """

    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    mask: torch.FloatTensor = None
    loc: torch.FloatTensor = None
    scale: torch.FloatTensor = None
    patch_input: torch.FloatTensor = None


@dataclass
class PatchTSTForPretrainingOutput(ModelOutput):
    """
    Output type of [`PatchTSTForPretraining`].

    Parameters:
        loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
            MSE loss.
        prediction_output (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, patch_length)`):
            Prediction outputs of the time series modeling heads.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    prediction_output: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class PatchTSTForRegressionOutput(ModelOutput):
    """
    Output type of [`PatchTSTForRegression`].

    Parameters:
        loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
            MSE loss.
        regression_outputs (`torch.FloatTensor` of shape `(batch_size, num_targets)`):
            Regression outputs of the time series modeling heads.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    regression_outputs: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class PatchTSTForPredictionOutput(ModelOutput):
    """
    Output type of [`PatchTSTForPrediction`].

    Parameters:
        loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
            MSE loss.
        prediction_outputs (`torch.FloatTensor` of shape `(batch_size, prediction_length, -1)`):
            Prediction outputs of the time series modeling heads.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        loc (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*):
            Mean of the input data (batch_size, sequence_length, num_channels) over the sequence_length
        scale (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*):
            Std of the input data (batch_size, sequence_length, num_channels) over the sequence_length
    """

    loss: Optional[torch.FloatTensor] = None
    prediction_outputs: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    loc: torch.FloatTensor = None
    scale: torch.FloatTensor = None


@dataclass
class PatchTSTForClassificationOutput(ModelOutput):
    """
    Output type of [`PatchTSTForClassification`].

    Parameters:
        loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
            Total loss as the sum of the masked language modeling loss and the next sequence prediction
            (classification) loss.
        prediction_logits (`torch.FloatTensor` of shape `(batch_size, num_targets)`):
            Prediction scores of the PatchTST modeling head (scores before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    prediction_logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class SamplePatchTSTOutput(ModelOutput):
    """
    Base class for time series model's predictions outputs that contains the sampled values from the chosen
    distribution.

    Parameters:
        sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length, num_targets)`):
            Sampled values from the chosen distribution.
    """

    sequences: torch.FloatTensor = None


# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.nll
def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor:
    """
    Computes the negative log likelihood loss from input distribution with respect to target.
    """
    return -input.log_prob(target)


# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.weighted_average
def weighted_average(input_tensor: torch.Tensor, weights: Optional[torch.Tensor] = None, dim=None) -> torch.Tensor:
    """
    Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero,
    meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`.

    Args:
        input_tensor (`torch.FloatTensor`):
            Input tensor, of which the average must be computed.
        weights (`torch.FloatTensor`, *optional*):
            Weights tensor, of the same shape as `input_tensor`.
        dim (`int`, *optional*):
            The dim along which to average `input_tensor`.

    Returns:
        `torch.FloatTensor`: The tensor with values averaged along the specified `dim`.
    """
    if weights is not None:
        weighted_tensor = torch.where(weights != 0, input_tensor * weights, torch.zeros_like(input_tensor))
        sum_weights = torch.clamp(weights.sum(dim=dim) if dim else weights.sum(), min=1.0)
        return (weighted_tensor.sum(dim=dim) if dim else weighted_tensor.sum()) / sum_weights
    else:
        return input_tensor.mean(dim=dim)
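
# Illustrative sketch (hedged; the tensors below are made-up example values): `weighted_average`
# can reduce a per-position loss with an observation mask as `weights`, ignoring unobserved
# positions in both the numerator and the weight sum:
#
#     >>> loss = torch.tensor([[1.0, 2.0, 4.0]])
#     >>> mask = torch.tensor([[1.0, 0.0, 1.0]])
#     >>> weighted_average(loss, weights=mask, dim=1)   # (1 + 4) / 2
#     tensor([2.5000])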


# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesStdScaler with TimeSeriesTransformer->PatchTST,TimeSeries->PatchTST
class PatchTSTStdScaler(nn.Module):
    """
    Standardize features by computing the mean and standard deviation along the first dimension, and then normalize
    by subtracting the mean and dividing by the standard deviation.
    """

    def __init__(self, config: PatchTSTConfig):
        super().__init__()
        self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
        self.keepdim = config.keepdim if hasattr(config, "keepdim") else True
        self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-5

    def forward(
        self, data: torch.Tensor, observed_indicator: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Parameters:
            data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Input for the scaler calculation
            observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Calculating the scale on the observed indicator.
        Returns:
            tuple of `torch.Tensor` of shapes
                (`(batch_size, sequence_length, num_input_channels)`, `(batch_size, 1, num_input_channels)`,
                `(batch_size, 1, num_input_channels)`)
        """
        denominator = observed_indicator.sum(self.dim, keepdim=self.keepdim)
        denominator = denominator.clamp_min(1.0)
        loc = (data * observed_indicator).sum(self.dim, keepdim=self.keepdim) / denominator

        variance = (((data - loc) * observed_indicator) ** 2).sum(self.dim, keepdim=self.keepdim) / denominator
        scale = torch.sqrt(variance + self.minimum_scale)
        return (data - loc) / scale, loc, scale
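
# Illustrative sketch (hedged; `config` is assumed to be a `PatchTSTConfig` using the default
# scaling settings above, i.e. dim=1, keepdim=True, minimum_scale=1e-5), with every value observed:
#
#     >>> data = torch.tensor([[[1.0], [2.0], [3.0]]])    # (bs=1, seq_len=3, channels=1)
#     >>> observed = torch.ones_like(data)
#     >>> scaled, loc, scale = PatchTSTStdScaler(config)(data, observed)
#     >>> loc                                              # mean over seq_len
#     tensor([[[2.]]])
#     >>> scale                                            # sqrt(biased variance + 1e-5)
#     tensor([[[0.8165]]])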
  863. # Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesMeanScaler with TimeSeriesTransformer->PatchTST,TimeSeries->PatchTST
  864. class PatchTSTMeanScaler(nn.Module):
  865. """
  866. Computes a scaling factor as the weighted average absolute value along the first dimension, and scales the data
  867. accordingly.
  868. """
  869. def __init__(self, config: PatchTSTConfig):
  870. super().__init__()
  871. self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
  872. self.keepdim = config.keepdim if hasattr(config, "keepdim") else True
  873. self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-10
  874. self.default_scale = config.default_scale if hasattr(config, "default_scale") else None
  875. def forward(
  876. self, data: torch.Tensor, observed_indicator: torch.Tensor
  877. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  878. """
  879. Parameters:
  880. data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
  881. input for Batch norm calculation
  882. observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
  883. Calculating the scale on the observed indicator.
  884. Returns:
  885. tuple of `torch.Tensor` of shapes
  886. (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
  887. `(batch_size, 1, num_input_channels)`)
  888. """
  889. ts_sum = (data * observed_indicator).abs().sum(self.dim, keepdim=True)
  890. num_observed = observed_indicator.sum(self.dim, keepdim=True)
  891. scale = ts_sum / torch.clamp(num_observed, min=1)
  892. # If `default_scale` is provided, we use it, otherwise we use the scale
  893. # of the batch.
  894. if self.default_scale is None:
  895. batch_sum = ts_sum.sum(dim=0)
  896. batch_observations = torch.clamp(num_observed.sum(0), min=1)
  897. default_scale = torch.squeeze(batch_sum / batch_observations)
  898. else:
  899. default_scale = self.default_scale * torch.ones_like(scale)
  900. # apply default scale where there are no observations
  901. scale = torch.where(num_observed > 0, scale, default_scale)
  902. # ensure the scale is at least `self.minimum_scale`
  903. scale = torch.clamp(scale, min=self.minimum_scale)
  904. scaled_data = data / scale
  905. if not self.keepdim:
  906. scale = scale.squeeze(dim=self.dim)
  907. return scaled_data, torch.zeros_like(scale), scale
  908. # Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesNOPScaler with TimeSeriesTransformer->PatchTST,TimeSeries->PatchTST
  909. class PatchTSTNOPScaler(nn.Module):
  910. """
  911. Assigns a scaling factor equal to 1 along the first dimension, and therefore applies no scaling to the input data.
  912. """
  913. def __init__(self, config: PatchTSTConfig):
  914. super().__init__()
  915. self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
  916. self.keepdim = config.keepdim if hasattr(config, "keepdim") else True
  917. def forward(
  918. self, data: torch.Tensor, observed_indicator: torch.Tensor = None
  919. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  920. """
  921. Parameters:
  922. data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
923. Input for scale calculation.
  924. Returns:
  925. tuple of `torch.Tensor` of shapes
  926. (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
  927. `(batch_size, 1, num_input_channels)`)
  928. """
  929. scale = torch.ones_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim)
  930. loc = torch.zeros_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim)
  931. return data, loc, scale
  932. class PatchTSTScaler(nn.Module):
  933. def __init__(self, config: PatchTSTConfig):
  934. super().__init__()
  935. if config.scaling == "mean" or config.scaling is True:
  936. self.scaler = PatchTSTMeanScaler(config)
  937. elif config.scaling == "std":
  938. self.scaler = PatchTSTStdScaler(config)
  939. else:
  940. self.scaler = PatchTSTNOPScaler(config)
  941. def forward(
  942. self, data: torch.Tensor, observed_indicator: torch.Tensor
  943. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  944. """
  945. Parameters:
  946. data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
  947. Input for scaler calculation
948. observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
949. Boolean mask indicating which entries of `data` are observed; the scale is computed over observed values only.
  950. Returns:
  951. tuple of `torch.Tensor` of shapes
  952. (`(batch_size, sequence_length, num_input_channels)`,`(batch_size, 1, num_input_channels)`,
953. `(batch_size, 1, num_input_channels)`)
  954. """
  955. data, loc, scale = self.scaler(data, observed_indicator)
  956. return data, loc, scale
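# Usage sketch (hypothetical values, for illustration only): `config.scaling` selects the
# behaviour above -- "mean" (or True) -> PatchTSTMeanScaler, "std" -> PatchTSTStdScaler,
# anything else -> PatchTSTNOPScaler. For example:
#     config = PatchTSTConfig(num_input_channels=7, context_length=512, scaling="std")
#     scaler = PatchTSTScaler(config)
#     past_values = torch.randn(2, 512, 7)
#     scaled, loc, scale = scaler(past_values, torch.ones_like(past_values))
#     # scaled: (2, 512, 7), loc: (2, 1, 7), scale: (2, 1, 7)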
  957. @add_start_docstrings(
  958. "The bare PatchTST Model outputting raw hidden-states without any specific head.",
  959. PATCHTST_START_DOCSTRING,
  960. )
  961. class PatchTSTModel(PatchTSTPreTrainedModel):
  962. def __init__(self, config: PatchTSTConfig):
  963. super().__init__(config)
  964. self.scaler = PatchTSTScaler(config)
  965. self.patchifier = PatchTSTPatchify(config)
  966. self.do_mask_input = config.do_mask_input
  967. # get num_patches information from PatchTSTPatchify
  968. num_patches = self.patchifier.num_patches
  969. if self.do_mask_input:
  970. self.masking = PatchTSTMasking(config)
  971. else:
  972. self.masking = nn.Identity()
  973. self.encoder = PatchTSTEncoder(config, num_patches=num_patches)
  974. # Initialize weights and apply final processing
  975. self.post_init()
  976. def forward(
  977. self,
  978. past_values: torch.Tensor,
  979. past_observed_mask: Optional[torch.Tensor] = None,
  980. future_values: Optional[torch.Tensor] = None,
  981. output_hidden_states: Optional[bool] = None,
  982. output_attentions: Optional[bool] = None,
  983. return_dict: Optional[bool] = None,
  984. ) -> Union[Tuple, PatchTSTModelOutput]:
  985. r"""
  986. Parameters:
  987. past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*):
  988. Input sequence to the model
  989. past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
  990. Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
  991. in `[0, 1]`:
  992. - 1 for values that are **observed**,
  993. - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
994. future_values (`torch.Tensor` of shape `(batch_size, prediction_length, num_input_channels)`, *optional*):
  995. Future target values associated with the `past_values`
  996. output_hidden_states (`bool`, *optional*):
  997. Whether or not to return the hidden states of all layers
  998. output_attentions (`bool`, *optional*):
  999. Whether or not to return the output attention of all layers
  1000. return_dict (`bool`, *optional*):
  1001. Whether or not to return a `ModelOutput` instead of a plain tuple.
  1002. Returns:
  1003. `PatchTSTModelOutput` or tuple of `torch.Tensor` (if `return_dict`=False or `config.return_dict`=False)
  1004. Examples:
  1005. ```python
  1006. >>> from huggingface_hub import hf_hub_download
  1007. >>> import torch
  1008. >>> from transformers import PatchTSTModel
  1009. >>> file = hf_hub_download(
  1010. ... repo_id="hf-internal-testing/etth1-hourly-batch", filename="train-batch.pt", repo_type="dataset"
  1011. ... )
  1012. >>> batch = torch.load(file)
  1013. >>> model = PatchTSTModel.from_pretrained("namctin/patchtst_etth1_pretrain")
  1014. >>> # during training, one provides both past and future values
  1015. >>> outputs = model(
  1016. ... past_values=batch["past_values"],
  1017. ... future_values=batch["future_values"],
  1018. ... )
  1019. >>> last_hidden_state = outputs.last_hidden_state
  1020. ```"""
  1021. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1022. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1023. output_hidden_states = (
  1024. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1025. )
  1026. if past_observed_mask is None:
  1027. past_observed_mask = torch.ones_like(past_values)
1028. # past_values: tensor [bs x sequence_length x num_input_channels]
  1029. scaled_past_values, loc, scale = self.scaler(past_values, past_observed_mask)
  1030. # patched_values: [bs x num_input_channels x num_patches x patch_length] for pretrain
  1031. patched_values = self.patchifier(scaled_past_values)
  1032. if self.do_mask_input:
  1033. masked_values, mask = self.masking(patched_values)
  1034. else:
  1035. masked_values, mask = self.masking(patched_values), None
  1036. encoder_output = self.encoder(
  1037. patch_input=masked_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions
  1038. )
  1039. if not return_dict:
  1040. outputs = (encoder_output.last_hidden_state, encoder_output.hidden_states, encoder_output.attentions)
  1041. outputs = outputs + (mask, loc, scale, patched_values)
  1042. return tuple(v for v in outputs if v is not None)
  1043. return PatchTSTModelOutput(
  1044. last_hidden_state=encoder_output.last_hidden_state,
  1045. hidden_states=encoder_output.hidden_states,
  1046. attentions=encoder_output.attentions,
  1047. mask=mask,
  1048. loc=loc,
  1049. scale=scale,
  1050. patch_input=patched_values,
  1051. )
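# Shape summary for PatchTSTModel.forward (informal, inferred from the code above):
#     past_values          (bs, sequence_length, num_input_channels)
#     -> scaler            scaled values, plus loc/scale of shape (bs, 1, num_input_channels)
#     -> patchifier        (bs, num_input_channels, num_patches, patch_length)
#     -> optional masking  same shape, plus a boolean patch mask when do_mask_input=True
#     -> encoder           last_hidden_state (bs, num_input_channels, num_patches, d_model),
#                          or num_patches+1 along the patch axis when use_cls_token=True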
  1052. class PatchTSTMaskPretrainHead(nn.Module):
  1053. """
  1054. Pretraining head for mask modelling
  1055. """
  1056. def __init__(self, config: PatchTSTConfig):
  1057. super().__init__()
  1058. self.dropout = nn.Dropout(config.head_dropout) if config.head_dropout > 0 else nn.Identity()
  1059. self.linear = nn.Linear(config.d_model, config.patch_length)
  1060. self.use_cls_token = config.use_cls_token
  1061. def forward(self, embedding: torch.Tensor) -> torch.Tensor:
  1062. """
  1063. Parameters:
  1064. embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or
  1065. `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*):
  1066. Embedding from the model
1067. Returns:
1068. `torch.Tensor` of shape `(bs, num_channels, num_patches, patch_length)`; when `cls_token`
1069. is set to True, the cls token is removed so the output still has `num_patches` patches
1070. """
  1071. embedding = self.linear(self.dropout(embedding)) # [bs x num_channels x num_patches x patch_length]
  1072. if self.use_cls_token:
  1073. embedding = embedding[:, :, 1:, :] # remove the first cls token
  1074. return embedding
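# Note (informal): the pretraining head is just dropout followed by a d_model -> patch_length
# linear map applied patch-wise, so every encoder patch embedding is decoded back into a raw patch:
#     (bs, num_channels, num_patches, d_model) -> (bs, num_channels, num_patches, patch_length)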
  1075. @add_start_docstrings(
  1076. "The PatchTST for pretrain model.",
  1077. PATCHTST_START_DOCSTRING,
  1078. )
  1079. class PatchTSTForPretraining(PatchTSTPreTrainedModel):
  1080. def __init__(self, config: PatchTSTConfig):
  1081. super().__init__(config)
  1082. config.do_mask_input = True
  1083. self.model = PatchTSTModel(config=config)
  1084. self.head = PatchTSTMaskPretrainHead(config)
  1085. # Initialize weights and apply final processing
  1086. self.post_init()
  1087. def forward(
  1088. self,
  1089. past_values: torch.Tensor,
  1090. past_observed_mask: Optional[torch.Tensor] = None,
  1091. output_hidden_states: Optional[bool] = None,
  1092. output_attentions: Optional[bool] = None,
  1093. return_dict: Optional[bool] = None,
  1094. ) -> Union[Tuple, PatchTSTForPretrainingOutput]:
  1095. r"""
  1096. Parameters:
  1097. past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*):
  1098. Input sequence to the model
  1099. past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
  1100. Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
  1101. in `[0, 1]`:
  1102. - 1 for values that are **observed**,
  1103. - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
  1104. output_hidden_states (`bool`, *optional*):
  1105. Whether or not to return the hidden states of all layers
  1106. output_attentions (`bool`, *optional*):
  1107. Whether or not to return the output attention of all layers
  1108. return_dict (`bool`, *optional*): Whether or not to return a `ModelOutput` instead of a plain tuple.
  1109. Returns:
  1110. `PatchTSTForPretrainingOutput` or tuple of `torch.Tensor` (if `return_dict`=False or
  1111. `config.return_dict`=False)
  1112. Examples:
  1113. ```python
  1114. >>> from huggingface_hub import hf_hub_download
  1115. >>> import torch
  1116. >>> from transformers import PatchTSTConfig, PatchTSTForPretraining
  1117. >>> file = hf_hub_download(
  1118. ... repo_id="hf-internal-testing/etth1-hourly-batch", filename="train-batch.pt", repo_type="dataset"
  1119. ... )
  1120. >>> batch = torch.load(file)
  1121. >>> # Config for random mask pretraining
  1122. >>> config = PatchTSTConfig(
  1123. ... num_input_channels=7,
  1124. ... context_length=512,
  1125. ... patch_length=12,
  1126. ... stride=12,
  1127. ... mask_type='random',
  1128. ... random_mask_ratio=0.4,
  1129. ... use_cls_token=True,
  1130. ... )
  1131. >>> # Config for forecast mask pretraining
  1132. >>> config = PatchTSTConfig(
  1133. ... num_input_channels=7,
  1134. ... context_length=512,
  1135. ... patch_length=12,
  1136. ... stride=12,
  1137. ... mask_type='forecast',
  1138. ... num_forecast_mask_patches=5,
  1139. ... use_cls_token=True,
  1140. ... )
  1141. >>> model = PatchTSTForPretraining(config)
1142. >>> # during pretraining, one only provides past values; the model reconstructs masked patches
  1143. >>> outputs = model(past_values=batch["past_values"])
  1144. >>> loss = outputs.loss
  1145. >>> loss.backward()
  1146. ```"""
  1147. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1148. # model_output.last_hidden_state: [bs x num_channels x num_patches x d_model] or
1149. # [bs x num_channels x (num_patches+1) x d_model] if cls_token is used
  1150. model_output = self.model(
  1151. past_values=past_values,
  1152. past_observed_mask=past_observed_mask,
  1153. output_hidden_states=output_hidden_states,
  1154. output_attentions=output_attentions,
  1155. return_dict=True,
  1156. )
1157. # x_hat: reconstructed patches of shape [bs x num_channels x num_patches x patch_length]
1158. # (the head drops the cls token, if any, before returning)
  1159. x_hat = self.head(model_output.last_hidden_state)
  1160. # calculate masked_loss
  1161. loss = nn.MSELoss(reduction="none")
  1162. loss_val = loss(x_hat, model_output.patch_input)
  1163. masked_loss = (loss_val.mean(dim=-1) * model_output.mask).sum() / (model_output.mask.sum() + 1e-10)
  1164. encoder_states = model_output.hidden_states
  1165. if not return_dict:
  1166. outputs = (x_hat,) + model_output[1:-4]
  1167. outputs = (masked_loss,) + outputs if masked_loss is not None else outputs
  1168. return outputs
  1169. return PatchTSTForPretrainingOutput(
  1170. loss=masked_loss, prediction_output=x_hat, hidden_states=encoder_states, attentions=model_output.attentions
  1171. )
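# Informal sketch of the pretraining objective computed above: an MSE that only counts masked patches,
#     loss = sum_over_masked_patches( mean_over_patch((x_hat - x) ** 2) ) / (number_of_masked_patches + eps)
# which in code is  (loss_val.mean(dim=-1) * mask).sum() / (mask.sum() + 1e-10).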
  1172. class PatchTSTClassificationHead(nn.Module):
  1173. def __init__(self, config: PatchTSTConfig):
  1174. super().__init__()
  1175. self.use_cls_token = config.use_cls_token
  1176. self.pooling_type = config.pooling_type
  1177. self.flatten = nn.Flatten(start_dim=1)
  1178. self.dropout = nn.Dropout(config.head_dropout) if config.head_dropout > 0 else nn.Identity()
  1179. self.linear = nn.Linear(config.num_input_channels * config.d_model, config.num_targets)
  1180. def forward(self, embedding: torch.Tensor):
  1181. """
  1182. Parameters:
  1183. embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or
  1184. `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*):
  1185. Embedding from the model
  1186. Returns:
  1187. `torch.Tensor` of shape `(bs, num_targets)`
  1188. """
  1189. if self.use_cls_token:
  1190. # use the first output token, pooled_embedding: bs x num_channels x d_model
  1191. pooled_embedding = embedding[:, :, 0, :]
  1192. elif self.pooling_type == "mean":
  1193. # pooled_embedding: [bs x num_channels x d_model]
  1194. pooled_embedding = embedding.mean(dim=2)
  1195. elif self.pooling_type == "max":
  1196. # pooled_embedding: [bs x num_channels x d_model]
  1197. pooled_embedding = embedding.max(dim=2).values
  1198. else:
  1199. raise ValueError(f"pooling operator {self.pooling_type} is not implemented yet")
  1200. # pooled_embedding: bs x num_channels * d_model
  1201. pooled_embedding = self.flatten(pooled_embedding)
  1202. # output: bs x n_classes
  1203. output = self.linear(self.dropout(pooled_embedding))
  1204. return output
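# Shape summary for the classification head (informal):
#     (bs, num_channels, num_patches, d_model)
#     -> cls token / mean / max pooling over patches -> (bs, num_channels, d_model)
#     -> flatten                                      -> (bs, num_channels * d_model)
#     -> dropout + linear                             -> (bs, num_targets)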
  1205. @add_start_docstrings(
  1206. "The PatchTST for classification model.",
  1207. PATCHTST_START_DOCSTRING,
  1208. )
  1209. class PatchTSTForClassification(PatchTSTPreTrainedModel):
  1210. def __init__(self, config: PatchTSTConfig):
  1211. super().__init__(config)
  1212. # Turn off masking
  1213. if config.do_mask_input:
  1214. logger.warning("Setting `do_mask_input` parameter to False.")
  1215. config.do_mask_input = False
  1216. self.model = PatchTSTModel(config)
  1217. self.head = PatchTSTClassificationHead(config)
  1218. # Initialize weights and apply final processing
  1219. self.post_init()
  1220. def forward(
  1221. self,
  1222. past_values: torch.Tensor,
  1223. target_values: torch.Tensor = None,
1224. past_observed_mask: Optional[torch.Tensor] = None,
  1225. output_hidden_states: Optional[bool] = None,
  1226. output_attentions: Optional[bool] = None,
  1227. return_dict: Optional[bool] = None,
  1228. ) -> Union[tuple, PatchTSTForClassificationOutput]:
  1229. r"""
  1230. Parameters:
  1231. past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*):
  1232. Input sequence to the model
  1233. target_values (`torch.Tensor`, *optional*):
1234. Labels associated with the `past_values`
  1235. past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
  1236. Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
  1237. in `[0, 1]`:
  1238. - 1 for values that are **observed**,
  1239. - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
  1240. output_hidden_states (`bool`, *optional*):
  1241. Whether or not to return the hidden states of all layers
  1242. output_attentions (`bool`, *optional*):
  1243. Whether or not to return the output attention of all layers
  1244. return_dict (`bool`, *optional*):
  1245. Whether or not to return a `ModelOutput` instead of a plain tuple.
  1246. Returns:
  1247. `PatchTSTForClassificationOutput` or tuple of `torch.Tensor` (if `return_dict`=False or
  1248. `config.return_dict`=False)
  1249. Examples:
  1250. ```python
  1251. >>> from transformers import PatchTSTConfig, PatchTSTForClassification
1252. >>> # classification task with two input channels and 3 classes
  1253. >>> config = PatchTSTConfig(
  1254. ... num_input_channels=2,
  1255. ... num_targets=3,
  1256. ... context_length=512,
  1257. ... patch_length=12,
  1258. ... stride=12,
  1259. ... use_cls_token=True,
  1260. ... )
  1261. >>> model = PatchTSTForClassification(config=config)
  1262. >>> # during inference, one only provides past values
  1263. >>> past_values = torch.randn(20, 512, 2)
  1264. >>> outputs = model(past_values=past_values)
1265. >>> prediction_logits = outputs.prediction_logits
  1266. ```"""
  1267. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1268. model_output = self.model(
  1269. past_values=past_values,
  1270. past_observed_mask=past_observed_mask,
  1271. output_hidden_states=output_hidden_states,
  1272. output_attentions=output_attentions,
  1273. return_dict=True,
  1274. )
  1275. y_hat = self.head(model_output.last_hidden_state)
  1276. loss_val = None
  1277. if target_values is not None:
  1278. loss = nn.CrossEntropyLoss()
  1279. loss_val = loss(y_hat, target_values)
  1280. if not return_dict:
  1281. outputs = (y_hat,) + model_output[1:-3]
  1282. outputs = (loss_val,) + outputs if loss_val is not None else outputs
  1283. return outputs
  1284. return PatchTSTForClassificationOutput(
  1285. loss=loss_val,
  1286. prediction_logits=y_hat,
  1287. hidden_states=model_output.hidden_states,
  1288. attentions=model_output.attentions,
  1289. )
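# Usage note (assumption based on the CrossEntropyLoss above): `target_values` is expected to be
# a tensor of integer class indices of shape (bs,), e.g.
#     outputs = model(past_values=past_values, target_values=torch.randint(0, config.num_targets, (20,)))
#     outputs.loss.backward()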
  1294. class PatchTSTPredictionHead(nn.Module):
  1295. def __init__(self, config: PatchTSTConfig, num_patches, distribution_output=None):
  1296. super().__init__()
  1297. self.share_projection = config.share_projection
  1298. self.num_input_channels = config.num_input_channels
  1299. self.use_cls_token = config.use_cls_token
  1300. self.pooling_type = config.pooling_type
  1301. if self.pooling_type or self.use_cls_token:
  1302. head_dim = config.d_model
  1303. else:
  1304. head_dim = config.d_model * num_patches
  1305. if not self.share_projection:
  1306. # if each channel has its own head
  1307. self.projections = nn.ModuleList()
  1308. self.dropouts = nn.ModuleList()
  1309. self.flattens = nn.ModuleList()
  1310. for i in range(self.num_input_channels):
  1311. self.flattens.append(nn.Flatten(start_dim=2))
  1312. if distribution_output is None:
  1313. # use linear head
  1314. self.projections.append(nn.Linear(head_dim, config.prediction_length))
  1315. else:
  1316. # use distribution head
  1317. self.projections.append(distribution_output.get_parameter_projection(head_dim))
  1318. self.dropouts.append(nn.Dropout(config.head_dropout) if config.head_dropout > 0 else nn.Identity())
  1319. else:
  1320. # all the channels share the same head
  1321. self.flatten = nn.Flatten(start_dim=2)
  1322. if distribution_output is None:
  1323. # use linear head
  1324. self.projection = nn.Linear(head_dim, config.prediction_length)
  1325. else:
  1326. # use distribution head
  1327. self.projection = distribution_output.get_parameter_projection(head_dim)
  1328. self.dropout = nn.Dropout(config.head_dropout) if config.head_dropout > 0 else nn.Identity()
  1329. def forward(self, embedding: torch.Tensor):
  1330. """
  1331. Parameters:
  1332. embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or
  1333. `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*):
  1334. Embedding from the model
  1335. Returns:
  1336. `torch.Tensor` of shape `(bs, forecast_len, num_channels)`
  1337. """
  1338. if self.use_cls_token:
  1339. # pooled_embedding: [bs x num_channels x d_model]
  1340. pooled_embedding = embedding[:, :, 0, :]
  1341. else:
  1342. if self.pooling_type == "mean":
  1343. # pooled_embedding: [bs x num_channels x d_model]
  1344. pooled_embedding = embedding.mean(dim=2)
  1345. elif self.pooling_type == "max":
  1346. # pooled_embedding: [bs x num_channels x d_model]
  1347. pooled_embedding = embedding.max(dim=2).values
  1348. else:
  1349. # pooled_embedding: [bs x num_channels x num_patches x d_model]
  1350. pooled_embedding = embedding
  1351. if not self.share_projection:
  1352. output = []
  1353. for i in range(self.num_input_channels):
1354. # channel_embedding: [bs x (d_model * num_patches)] or [bs x d_model]; a separate variable avoids overwriting `pooled_embedding` for later channels
1355. channel_embedding = self.flattens[i](pooled_embedding[:, i, :])
1356. channel_embedding = self.dropouts[i](channel_embedding)
1357. # channel_embedding: [bs x forecast_len]
1358. # or tuple ([bs x forecast_len], [bs x forecast_len]) if using distribution head
1359. channel_embedding = self.projections[i](channel_embedding)
1360. output.append(channel_embedding)
  1361. # output: [bs x num_channels x forecast_len]
  1362. output = torch.stack(output, dim=1)
  1363. else:
1364. # pooled_embedding: [bs x num_channels x (d_model * num_patches)] or [bs x num_channels x d_model]
  1365. pooled_embedding = self.flatten(pooled_embedding)
  1366. pooled_embedding = self.dropout(pooled_embedding)
  1367. # output: [bs x num_channels x forecast_len] or
  1368. # tuple ([bs x num_channels x forecast_len], [bs x num_channels x forecast_len]) if using distribution head
  1369. output = self.projection(pooled_embedding)
  1370. if isinstance(output, tuple):
  1371. # output: ([bs x forecast_len x num_channels], [bs x forecast_len x num_channels])
  1372. output = tuple(z.transpose(2, 1) for z in output)
  1373. else:
  1374. output = output.transpose(2, 1) # [bs x forecast_len x num_channels]
  1375. return output
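# Shape summary for the prediction head (informal): with pooling or a cls token each channel is
# reduced to a d_model vector, otherwise its patches are flattened to num_patches * d_model; the
# per-channel (share_projection=False) or shared projection then maps this to prediction_length,
# and the result is transposed to (bs, prediction_length, num_channels). With a distribution head
# the projection instead returns a tuple of distribution parameters with the same leading shape.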
  1376. @add_start_docstrings(
  1377. "The PatchTST for prediction model.",
  1378. PATCHTST_START_DOCSTRING,
  1379. )
  1380. class PatchTSTForPrediction(PatchTSTPreTrainedModel):
  1381. def __init__(self, config: PatchTSTConfig):
  1382. super().__init__(config)
  1383. # Turn off masking
  1384. if config.do_mask_input:
  1385. logger.warning("Setting `do_mask_input` parameter to False.")
  1386. config.do_mask_input = False
  1387. self.model = PatchTSTModel(config)
  1388. if config.loss == "mse":
  1389. self.distribution_output = None
  1390. else:
  1391. if config.distribution_output == "student_t":
  1392. self.distribution_output = StudentTOutput(dim=config.prediction_length)
  1393. elif config.distribution_output == "normal":
  1394. self.distribution_output = NormalOutput(dim=config.prediction_length)
  1395. elif config.distribution_output == "negative_binomial":
  1396. self.distribution_output = NegativeBinomialOutput(dim=config.prediction_length)
  1397. else:
  1398. raise ValueError(f"Unknown distribution output {config.distribution_output}")
  1399. self.head = PatchTSTPredictionHead(
  1400. config, self.model.patchifier.num_patches, distribution_output=self.distribution_output
  1401. )
  1402. # Initialize weights and apply final processing
  1403. self.post_init()
  1404. def forward(
  1405. self,
  1406. past_values: torch.Tensor,
  1407. past_observed_mask: Optional[torch.Tensor] = None,
  1408. future_values: Optional[torch.Tensor] = None,
  1409. output_hidden_states: Optional[bool] = None,
  1410. output_attentions: Optional[bool] = None,
  1411. return_dict: Optional[bool] = None,
  1412. ) -> Union[Tuple, PatchTSTForPredictionOutput]:
  1413. r"""
  1414. Parameters:
  1415. past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*):
  1416. Input sequence to the model
  1417. past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
  1418. Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
  1419. in `[0, 1]`:
  1420. - 1 for values that are **observed**,
  1421. - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
  1422. future_values (`torch.Tensor` of shape `(bs, forecast_len, num_input_channels)`, *optional*):
  1423. Future target values associated with the `past_values`
  1424. output_hidden_states (`bool`, *optional*):
  1425. Whether or not to return the hidden states of all layers
  1426. output_attentions (`bool`, *optional*):
  1427. Whether or not to return the output attention of all layers
  1428. return_dict (`bool`, *optional*):
  1429. Whether or not to return a `ModelOutput` instead of a plain tuple.
  1430. Returns:
  1431. `PatchTSTForPredictionOutput` or tuple of `torch.Tensor` (if `return_dict`=False or
  1432. `config.return_dict`=False)
  1433. Examples:
  1434. ```python
  1435. >>> from huggingface_hub import hf_hub_download
  1436. >>> import torch
  1437. >>> from transformers import PatchTSTConfig, PatchTSTForPrediction
  1438. >>> file = hf_hub_download(
  1439. ... repo_id="hf-internal-testing/etth1-hourly-batch", filename="train-batch.pt", repo_type="dataset"
  1440. ... )
  1441. >>> batch = torch.load(file)
1442. >>> # Prediction task with 7 input channels and a prediction length of 96
  1443. >>> model = PatchTSTForPrediction.from_pretrained("namctin/patchtst_etth1_forecast")
  1444. >>> # during training, one provides both past and future values
  1445. >>> outputs = model(
  1446. ... past_values=batch["past_values"],
  1447. ... future_values=batch["future_values"],
  1448. ... )
  1449. >>> loss = outputs.loss
  1450. >>> loss.backward()
  1451. >>> # during inference, one only provides past values, the model outputs future values
  1452. >>> outputs = model(past_values=batch["past_values"])
  1453. >>> prediction_outputs = outputs.prediction_outputs
  1454. ```"""
  1455. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1456. # get model output
  1457. model_output = self.model(
  1458. past_values=past_values,
  1459. past_observed_mask=past_observed_mask,
  1460. output_hidden_states=output_hidden_states,
  1461. output_attentions=output_attentions,
  1462. return_dict=True,
  1463. )
  1464. # get output head
  1465. y_hat = self.head(model_output.last_hidden_state)
  1466. loss_val = None
  1467. if self.distribution_output:
  1468. y_hat_out = y_hat
  1469. else:
  1470. y_hat_out = y_hat * model_output.scale + model_output.loc
  1471. if future_values is not None:
  1472. if self.distribution_output:
  1473. distribution = self.distribution_output.distribution(
  1474. y_hat, loc=model_output.loc, scale=model_output.scale
  1475. )
  1476. loss_val = nll(distribution, future_values)
  1477. # take average of the loss
  1478. loss_val = weighted_average(loss_val)
  1479. else:
  1480. loss = nn.MSELoss(reduction="mean")
  1481. loss_val = loss(y_hat_out, future_values)
  1482. loc = model_output.loc
  1483. scale = model_output.scale
  1484. if not return_dict:
  1485. outputs = (y_hat_out,) + model_output[1:-1]
  1486. outputs = (loss_val,) + outputs if loss_val is not None else outputs
  1487. return outputs
  1488. return PatchTSTForPredictionOutput(
  1489. loss=loss_val,
  1490. prediction_outputs=y_hat_out,
  1491. hidden_states=model_output.hidden_states,
  1492. attentions=model_output.attentions,
  1493. loc=loc,
  1494. scale=scale,
  1495. )
  1496. def generate(
  1497. self,
  1498. past_values: torch.Tensor,
  1499. past_observed_mask: Optional[torch.Tensor] = None,
  1500. ) -> SamplePatchTSTOutput:
  1501. """
  1502. Generate sequences of sample predictions from a model with a probability distribution head.
  1503. Parameters:
  1504. past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
1505. Past values of the time series that serve as context in order to predict the future.
  1506. past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
  1507. Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
  1508. in `[0, 1]`:
  1509. - 1 for values that are **observed**,
  1510. - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
  1511. Return:
  1512. [`SamplePatchTSTOutput`] where the outputs `sequences` tensor will have shape `(batch_size, number of
  1513. samples, prediction_length, 1)` or `(batch_size, number of samples, prediction_length, num_input_channels)`
  1514. for multivariate predictions.
  1515. """
  1516. # get number of samples
  1517. num_parallel_samples = self.config.num_parallel_samples
  1518. # get model output
  1519. outputs = self(
  1520. past_values=past_values,
  1521. future_values=None,
  1522. past_observed_mask=past_observed_mask,
  1523. output_hidden_states=False,
  1524. )
  1525. if self.distribution_output:
  1526. # get distribution
  1527. distribution = self.distribution_output.distribution(
  1528. outputs.prediction_outputs, loc=outputs.loc, scale=outputs.scale
  1529. )
  1530. # get samples: list of [bs x forecast_len x num_channels]
  1531. samples = [distribution.sample() for _ in range(num_parallel_samples)]
  1532. # samples: [bs x num_samples x forecast_len x num_channels]
  1533. samples = torch.stack(samples, dim=1)
  1534. else:
  1535. samples = outputs.prediction_outputs.unsqueeze(1)
  1536. return SamplePatchTSTOutput(sequences=samples)
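# Usage sketch for probabilistic forecasting (illustrative; `model` and `batch` as in the docstring
# example above, assuming the model was configured with config.loss="nll" so a distribution head is
# attached -- with "mse" the single point forecast is simply returned with an extra sample axis):
#     samples = model.generate(past_values=batch["past_values"]).sequences
#     # samples: (bs, num_parallel_samples, prediction_length, num_input_channels)
#     point_forecast = samples.median(dim=1).values   # or samples.mean(dim=1)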
  1537. class PatchTSTRegressionHead(nn.Module):
  1538. """
  1539. Regression head
  1540. """
  1541. def __init__(self, config: PatchTSTConfig, distribution_output=None):
  1542. super().__init__()
  1543. self.y_range = config.output_range
  1544. self.use_cls_token = config.use_cls_token
  1545. self.pooling_type = config.pooling_type
  1546. self.distribution_output = distribution_output
  1547. head_dim = config.num_input_channels * config.d_model
  1548. self.flatten = nn.Flatten(start_dim=1)
  1549. self.dropout = nn.Dropout(config.head_dropout) if config.head_dropout > 0 else nn.Identity()
  1550. if distribution_output is None:
  1551. self.projection = nn.Linear(head_dim, config.num_targets)
  1552. else:
  1553. self.projection = distribution_output.get_parameter_projection(head_dim)
  1554. def forward(self, embedding: torch.Tensor):
  1555. """
  1556. Parameters:
  1557. embedding (`torch.Tensor` of shape `(bs, num_channels, num_patches, d_model)` or
  1558. `(bs, num_channels, num_patches+1, d_model)` if `cls_token` is set to True, *required*):
  1559. Embedding from the model
  1560. Returns:
  1561. `torch.Tensor` of shape `(bs, output_dim)`
  1562. """
  1563. if self.use_cls_token:
  1564. # use the first output token, pooled_embedding: [bs x num_channels x d_model]
  1565. pooled_embedding = embedding[:, :, 0, :]
  1566. elif self.pooling_type == "mean":
  1567. # pooled_embedding: [bs x num_channels x d_model]
  1568. pooled_embedding = embedding.mean(dim=2)
  1569. elif self.pooling_type == "max":
  1570. # pooled_embedding: [bs x num_channels x d_model]
  1571. pooled_embedding = embedding.max(dim=2).values
  1572. else:
  1573. raise ValueError(f"pooling operator {self.pooling_type} is not implemented yet")
  1574. # flatten the input
  1575. # pooled_embedding: bs x (num_channels * d_model)
  1576. pooled_embedding = self.dropout(self.flatten(pooled_embedding))
  1577. # projection
  1578. # output: bs x output_dim or a tuple of this shape for distribution head
  1579. output = self.projection(pooled_embedding)
  1580. # apply sigmoid to bound the output if required
1581. if self.distribution_output is None and self.y_range is not None: # linear head
  1582. output = torch.sigmoid(output) * (self.y_range[1] - self.y_range[0]) + self.y_range[0]
  1583. return output
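# Note (informal): when `output_range=(low, high)` is set and a plain linear head is used, the
# sigmoid above squashes the projection into that interval:
#     output = sigmoid(z) * (high - low) + low
# With a distribution head the projection returns the distribution parameters unchanged.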
  1584. @add_start_docstrings(
  1585. "The PatchTST for regression model.",
  1586. PATCHTST_START_DOCSTRING,
  1587. )
  1588. class PatchTSTForRegression(PatchTSTPreTrainedModel):
  1589. def __init__(self, config: PatchTSTConfig):
  1590. super().__init__(config)
  1591. # Turn off masking
  1592. if config.do_mask_input:
  1593. logger.warning("Setting `do_mask_input` parameter to False.")
  1594. config.do_mask_input = False
  1595. self.model = PatchTSTModel(config)
  1596. if config.loss == "mse":
  1597. self.distribution_output = None
  1598. else:
  1599. if config.distribution_output == "student_t":
  1600. self.distribution_output = StudentTOutput(dim=config.num_targets)
  1601. elif config.distribution_output == "normal":
  1602. self.distribution_output = NormalOutput(dim=config.num_targets)
  1603. elif config.distribution_output == "negative_binomial":
  1604. self.distribution_output = NegativeBinomialOutput(dim=config.num_targets)
  1605. else:
  1606. raise ValueError(f"Unknown distribution output {config.distribution_output}")
  1607. self.head = PatchTSTRegressionHead(config, self.distribution_output)
  1608. # Initialize weights and apply final processing
  1609. self.post_init()
  1610. def forward(
  1611. self,
  1612. past_values: torch.Tensor,
  1613. target_values: torch.Tensor = None,
  1614. past_observed_mask: Optional[torch.Tensor] = None,
  1615. output_hidden_states: Optional[bool] = None,
  1616. output_attentions: Optional[bool] = None,
  1617. return_dict: Optional[bool] = None,
  1618. ) -> Union[tuple, PatchTSTForRegressionOutput]:
  1619. r"""
  1620. Parameters:
  1621. past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*):
  1622. Input sequence to the model
1623. target_values (`torch.Tensor` of shape `(bs, num_targets)`, *optional*):
1624. Target values associated with the `past_values`
  1625. past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
  1626. Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
  1627. in `[0, 1]`:
  1628. - 1 for values that are **observed**,
  1629. - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
  1630. output_hidden_states (`bool`, *optional*):
  1631. Whether or not to return the hidden states of all layers
  1632. output_attentions (`bool`, *optional*):
  1633. Whether or not to return the output attention of all layers
  1634. return_dict (`bool`, *optional*):
  1635. Whether or not to return a `ModelOutput` instead of a plain tuple.
  1636. Returns:
  1637. `PatchTSTForRegressionOutput` or tuple of `torch.Tensor` (if `return_dict`=False or
  1638. `config.return_dict`=False)
  1639. Examples:
  1640. ```python
  1641. >>> from transformers import PatchTSTConfig, PatchTSTForRegression
1642. >>> # Regression task with 6 input channels that regresses 2 targets
  1643. >>> model = PatchTSTForRegression.from_pretrained("namctin/patchtst_etth1_regression")
  1644. >>> # during inference, one only provides past values, the model outputs future values
  1645. >>> past_values = torch.randn(20, 512, 6)
  1646. >>> outputs = model(past_values=past_values)
  1647. >>> regression_outputs = outputs.regression_outputs
  1648. ```"""
  1649. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1650. model_output = self.model(
  1651. past_values=past_values,
  1652. past_observed_mask=past_observed_mask,
  1653. output_hidden_states=output_hidden_states,
  1654. output_attentions=output_attentions,
  1655. return_dict=True,
  1656. )
  1657. # get output head. y_hat is of shape [bs x num_targets] or tuple of this shape
  1658. y_hat = self.head(model_output.last_hidden_state)
  1659. loss = None
  1660. if target_values is not None:
  1661. if self.distribution_output:
  1662. distribution = self.distribution_output.distribution(y_hat)
  1663. # y_hat should be a 2-tuple, each with dimension [bs, num_targets]
  1664. y_hat = tuple([item.view(-1, self.config.num_targets) for item in y_hat])
  1665. loss = nll(distribution, target_values)
  1666. # take average of the loss
  1667. loss = weighted_average(loss)
  1668. else:
  1669. loss = nn.MSELoss(reduction="mean")
  1670. loss = loss(y_hat, target_values)
  1671. if not return_dict:
  1672. # hidden_states, attentions, mask
  1673. outputs = (y_hat,) + model_output[1:-3]
  1674. outputs = (loss,) + outputs if loss is not None else outputs
  1675. return outputs
  1676. return PatchTSTForRegressionOutput(
  1677. loss=loss,
  1678. regression_outputs=y_hat,
  1679. hidden_states=model_output.hidden_states,
  1680. attentions=model_output.attentions,
  1681. )
  1682. def generate(
  1683. self,
  1684. past_values: torch.Tensor,
  1685. past_observed_mask: Optional[torch.Tensor] = None,
  1686. ) -> SamplePatchTSTOutput:
  1687. """
  1688. Generate sequences of sample predictions from a model with a probability distribution head.
  1689. Parameters:
  1690. past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
1691. Past values of the time series that serve as context in order to predict the future.
  1692. past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
  1693. Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
  1694. in `[0, 1]`:
  1695. - 1 for values that are **observed**,
  1696. - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
  1697. Return:
  1698. [`SamplePatchTSTOutput`] where the outputs `sequences` tensor will have shape `(batch_size, number of
  1699. samples, num_targets)`.
  1700. """
  1701. # get number of samples
  1702. num_parallel_samples = self.config.num_parallel_samples
  1703. # get model output
  1704. outputs = self(
  1705. past_values=past_values,
  1706. target_values=None,
  1707. past_observed_mask=past_observed_mask,
  1708. output_hidden_states=False,
  1709. )
  1710. # get distribution
  1711. distribution = self.distribution_output.distribution(outputs.regression_outputs)
  1712. # get samples: list of [bs x num_targets]
  1713. samples = [distribution.sample() for _ in range(num_parallel_samples)]
  1714. # samples: [bs x num_samples x num_targets]
  1715. samples = torch.stack(samples, dim=1).view(-1, num_parallel_samples, self.config.num_targets)
  1716. return SamplePatchTSTOutput(sequences=samples)
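# Usage sketch (illustrative): PatchTSTForRegression.generate assumes a distribution head
# (config.loss="nll"); with an "mse" head, self.distribution_output is None and the call above
# would fail. A typical call:
#     samples = model.generate(past_values=past_values).sequences   # (bs, num_parallel_samples, num_targets)
#     point_estimate = samples.mean(dim=1)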