# coding=utf-8
# Copyright 2023 IBM and HuggingFace Inc. team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch PatchTSMixer model."""

import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput

from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_patchtsmixer import PatchTSMixerConfig


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "PatchTSMixerConfig"


PATCHTSMIXER_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
    heads etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
    usage and behavior.

    Parameters:
        config ([`PatchTSMixerConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
        mask_input (`bool`, *optional*, defaults to `False`):
            If `True`, masking will be enabled; `False` otherwise.
"""


PATCHTSMIXER_INPUTS_DOCSTRING = r"""
    Args:
        past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
            Context values of the time series. For a pretraining task, this denotes the input time series to predict
            the masked portion. For a forecasting task, this denotes the history/past time series values. Similarly,
            for classification or regression tasks, it denotes the appropriate context values of the time series.

            For univariate time series, the `num_input_channels` dimension should be 1. For multivariate time series,
            it is greater than 1.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


class PatchTSMixerGatedAttention(nn.Module):
    """
    Module that applies gated attention to input data.

    Args:
        in_size (`int`): The input size.
        out_size (`int`): The output size.
    """

    def __init__(self, in_size: int, out_size: int):
        super().__init__()
        self.attn_layer = nn.Linear(in_size, out_size)
        self.attn_softmax = nn.Softmax(dim=-1)

    def forward(self, inputs):
        attn_weight = self.attn_softmax(self.attn_layer(inputs))
        inputs = inputs * attn_weight
        return inputs
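

# Example (illustrative sketch; the shapes below are hypothetical): the gating
# block reweights features with a softmax over the last dimension, so it can be
# exercised in isolation:
#
#     gate = PatchTSMixerGatedAttention(in_size=16, out_size=16)
#     x = torch.randn(2, 3, 8, 16)  # [bs x n_vars x num_patches x d_model]
#     y = gate(x)                   # same shape; features rescaled by weights in (0, 1)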


# Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTBatchNorm with PatchTST->PatchTSMixer
class PatchTSMixerBatchNorm(nn.Module):
    """
    Compute batch normalization over the sequence length (time) dimension.
    """

    def __init__(self, config: PatchTSMixerConfig):
        super().__init__()
        self.batchnorm = nn.BatchNorm1d(config.d_model, eps=config.norm_eps)

    def forward(self, inputs: torch.Tensor):
        """
        Parameters:
            inputs (`torch.Tensor` of shape `(batch_size, sequence_length, d_model)`):
                input for Batch norm calculation
        Returns:
            `torch.Tensor` of shape `(batch_size, sequence_length, d_model)`
        """
        output = inputs.transpose(1, 2)  # output: (batch_size, d_model, sequence_length)
        output = self.batchnorm(output)
        return output.transpose(1, 2)


class PatchTSMixerPositionalEncoding(nn.Module):
    """
    Class for positional encoding
    """

    def __init__(self, config: PatchTSMixerConfig):
        super().__init__()
        # positional encoding: [num_patches x d_model]
        if config.use_positional_encoding:
            self.position_enc = self._init_pe(config)
        else:
            self.position_enc = nn.Parameter(torch.zeros(config.num_patches, config.d_model))

    @staticmethod
    def _init_pe(config: PatchTSMixerConfig) -> nn.Parameter:
        # Positional encoding
        if config.positional_encoding_type == "random":
            position_enc = nn.Parameter(torch.randn(config.num_patches, config.d_model), requires_grad=True)
        elif config.positional_encoding_type == "sincos":
            position_enc = torch.zeros(config.num_patches, config.d_model)
            position = torch.arange(0, config.num_patches).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, config.d_model, 2) * -(math.log(10000.0) / config.d_model))
            position_enc[:, 0::2] = torch.sin(position * div_term)
            position_enc[:, 1::2] = torch.cos(position * div_term)
            position_enc = position_enc - position_enc.mean()
            position_enc = position_enc / (position_enc.std() * 10)
            position_enc = nn.Parameter(position_enc, requires_grad=False)
        else:
            raise ValueError(
                f"{config.positional_encoding_type} is not a valid positional encoder. Available types are 'random' and 'sincos'."
            )
        return position_enc

    def forward(self, patch_input: torch.Tensor):
        # hidden_state: [bs x num_channels x num_patches x d_model]
        hidden_state = patch_input + self.position_enc
        return hidden_state
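

# Example (illustrative sketch; the config values are hypothetical): the "sincos"
# table follows the standard Transformer encoding,
#     PE[p, 2i]   = sin(p / 10000^(2i / d_model))
#     PE[p, 2i+1] = cos(p / 10000^(2i / d_model))
# and is then standardized, scaled down by 10x, and frozen (requires_grad=False):
#
#     cfg = PatchTSMixerConfig(
#         context_length=16, patch_length=2, patch_stride=2, d_model=16,
#         use_positional_encoding=True, positional_encoding_type="sincos",
#     )  # num_patches = (16 - 2) // 2 + 1 = 8
#     pos = PatchTSMixerPositionalEncoding(cfg)
#     out = pos(torch.zeros(2, 3, 8, 16))  # position table broadcasts over [bs x n_vars]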


class PatchTSMixerNormLayer(nn.Module):
    """Normalization block

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    """

    def __init__(self, config: PatchTSMixerConfig):
        super().__init__()
        self.norm_mlp = config.norm_mlp
        if "batch" in config.norm_mlp.lower():
            self.norm = PatchTSMixerBatchNorm(config)
        else:
            self.norm = nn.LayerNorm(config.d_model, eps=config.norm_eps)

    def forward(self, inputs: torch.Tensor):
        """
        Args:
            inputs (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, d_model)`):
                Input to the normalization layer.
        Returns:
            `torch.Tensor` of shape `(batch_size, num_channels, num_patches, d_model)`
        """
        if "batch" in self.norm_mlp.lower():
            # reshape the data to [batch_size*num_channels, num_patches, d_model]
            inputs_reshaped = torch.reshape(
                inputs,
                (
                    inputs.shape[0] * inputs.shape[1],
                    inputs.shape[2],
                    inputs.shape[3],
                ),
            )
            inputs_reshaped = self.norm(inputs_reshaped)
            # put the data back into its original shape
            inputs = torch.reshape(inputs_reshaped, inputs.shape)
        else:
            inputs = self.norm(inputs)
        return inputs


class PatchTSMixerMLP(nn.Module):
    def __init__(self, in_features, out_features, config):
        super().__init__()
        num_hidden = in_features * config.expansion_factor
        self.fc1 = nn.Linear(in_features, num_hidden)
        self.dropout1 = nn.Dropout(config.dropout)
        self.fc2 = nn.Linear(num_hidden, out_features)
        self.dropout2 = nn.Dropout(config.dropout)

    def forward(self, inputs: torch.Tensor):
        """
        Args:
            inputs (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, d_model)`):
                Input to the MLP layer.
        Returns:
            `torch.Tensor` of the same shape as `inputs`
        """
        inputs = self.dropout1(nn.functional.gelu(self.fc1(inputs)))
        inputs = self.fc2(inputs)
        inputs = self.dropout2(inputs)
        return inputs
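

# Example (illustrative sketch; numbers are hypothetical): the hidden width is
# in_features * expansion_factor, so with in_features=64 and expansion_factor=2
# the MLP maps 64 -> 128 -> out_features with GELU and dropout in between. The
# linear layers act on the last dimension only, so any leading shape works:
#
#     mlp = PatchTSMixerMLP(in_features=64, out_features=64, config=cfg)
#     y = mlp(torch.randn(2, 3, 8, 64))  # -> [2, 3, 8, 64]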


class PatchTSMixerChannelFeatureMixerBlock(nn.Module):
    """This module mixes the features in the channel dimension.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    """

    def __init__(self, config: PatchTSMixerConfig):
        super().__init__()
        self.norm = PatchTSMixerNormLayer(config)
        self.gated_attn = config.gated_attn
        self.mlp = PatchTSMixerMLP(
            in_features=config.num_input_channels,
            out_features=config.num_input_channels,
            config=config,
        )
        if config.gated_attn:
            self.gating_block = PatchTSMixerGatedAttention(
                in_size=config.num_input_channels, out_size=config.num_input_channels
            )

    def forward(self, inputs: torch.Tensor):
        """
        Args:
            inputs (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, d_model)`):
                input to the MLP layer
        Returns:
            `torch.Tensor` of the same shape as `inputs`
        """
        residual = inputs
        inputs = self.norm(inputs)
        inputs = inputs.permute(0, 3, 2, 1)
        if self.gated_attn:
            inputs = self.gating_block(inputs)
        inputs = self.mlp(inputs)
        inputs = inputs.permute(0, 3, 2, 1)
        out = inputs + residual
        return out
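

# Shape trace (illustrative; sizes are hypothetical): permute(0, 3, 2, 1) swaps
# the channel and feature axes so that the MLP's last dimension is num_channels,
# i.e. mixing happens across channels; the second permute restores the layout
# before the residual add:
#
#     x = torch.randn(2, 5, 8, 16)    # [bs, n_vars, num_patches, d_model]
#     x.permute(0, 3, 2, 1).shape     # torch.Size([2, 16, 8, 5])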


# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->PatchTSMixer
class PatchTSMixerAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[PatchTSMixerConfig] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.reshape(*proj_shape)
        value_states = value_states.reshape(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value


class PatchMixerBlock(nn.Module):
    """This module mixes the patch dimension.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    """

    def __init__(self, config: PatchTSMixerConfig):
        super().__init__()
        self.norm = PatchTSMixerNormLayer(config)
        self.self_attn = config.self_attn
        self.gated_attn = config.gated_attn
        self.mlp = PatchTSMixerMLP(
            in_features=config.num_patches,
            out_features=config.num_patches,
            config=config,
        )
        if config.gated_attn:
            self.gating_block = PatchTSMixerGatedAttention(in_size=config.num_patches, out_size=config.num_patches)
        if config.self_attn:
            self.self_attn_layer = PatchTSMixerAttention(
                embed_dim=config.d_model,
                num_heads=config.self_attn_heads,
                dropout=config.dropout,
            )
            self.norm_attn = PatchTSMixerNormLayer(config)

    def forward(self, hidden_state):
        """
        Args:
            hidden_state (`torch.Tensor`): Input tensor.
        Returns:
            `torch.Tensor`: Transformed tensor.
        """
        residual = hidden_state
        hidden_state = self.norm(hidden_state)

        if self.self_attn:
            batch_size, n_vars, num_patches, d_model = hidden_state.shape
            hidden_state_reshaped = hidden_state.reshape(batch_size * n_vars, num_patches, d_model)
            x_attn, _, _ = self.self_attn_layer(hidden_state_reshaped, output_attentions=False)
            x_attn = x_attn.reshape(batch_size, n_vars, num_patches, d_model)

        # Transpose so that num_patches is the last dimension
        hidden_state = hidden_state.transpose(2, 3)
        hidden_state = self.mlp(hidden_state)
        if self.gated_attn:
            hidden_state = self.gating_block(hidden_state)
        # Transpose back
        hidden_state = hidden_state.transpose(2, 3)

        if self.self_attn:
            hidden_state = self.norm_attn(hidden_state + x_attn)

        out = hidden_state + residual
        return out


class FeatureMixerBlock(nn.Module):
    """This module mixes the hidden feature dimension.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    """

    def __init__(self, config: PatchTSMixerConfig):
        super().__init__()
        self.norm = PatchTSMixerNormLayer(config)
        self.gated_attn = config.gated_attn
        self.mlp = PatchTSMixerMLP(
            in_features=config.d_model,
            out_features=config.d_model,
            config=config,
        )
        if config.gated_attn:
            self.gating_block = PatchTSMixerGatedAttention(in_size=config.d_model, out_size=config.d_model)

    def forward(self, hidden: torch.Tensor):
        """
        Args:
            hidden (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, d_model)`):
                Input tensor to the layer.
        Returns:
            `torch.Tensor`: Transformed tensor.
        """
        residual = hidden
        hidden = self.norm(hidden)
        hidden = self.mlp(hidden)
        if self.gated_attn:
            hidden = self.gating_block(hidden)
        out = hidden + residual
        return out


class PatchTSMixerLayer(nn.Module):
    """
    The `PatchTSMixer` layer that does all three kinds of mixing.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    """

    def __init__(self, config: PatchTSMixerConfig):
        super().__init__()
        self.patch_mixer = PatchMixerBlock(config=config)
        self.feature_mixer = FeatureMixerBlock(config=config)
        self.mode = config.mode
        if config.mode == "mix_channel":
            self.channel_feature_mixer = PatchTSMixerChannelFeatureMixerBlock(config=config)

    def forward(self, hidden: torch.Tensor):
        """
        Args:
            hidden (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, d_model)`):
                Input tensor to the layer.
        Returns:
            `torch.Tensor`: Transformed tensor.
        """
        if self.mode == "mix_channel":
            hidden = self.channel_feature_mixer(hidden)
        hidden = self.patch_mixer(hidden)
        hidden = self.feature_mixer(hidden)  # hidden: (batch_size x num_patches x d_model)
        return hidden
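

# Example (illustrative sketch; reuses the hypothetical cfg from the earlier
# sketches): in "mix_channel" mode one layer applies channel mixing, then patch
# mixing, then feature mixing, each shape-preserving on
# [bs x n_vars x num_patches x d_model]:
#
#     layer = PatchTSMixerLayer(config=cfg)         # assumes cfg.mode == "mix_channel"
#     y = layer(torch.randn(2, 3, 8, cfg.d_model))  # same shape out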


class PatchTSMixerBlock(nn.Module):
    """The main computing framework of the `PatchTSMixer` model.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    """

    def __init__(self, config: PatchTSMixerConfig):
        super().__init__()
        num_layers = config.num_layers
        self.mixers = nn.ModuleList([PatchTSMixerLayer(config=config) for _ in range(num_layers)])

    def forward(self, hidden_state, output_hidden_states: bool = False):
        """
        Args:
            hidden_state (`torch.Tensor`): The input tensor.
            output_hidden_states (`bool`, *optional*, defaults to `False`):
                Whether to output the hidden states as well.
        Returns:
            `torch.Tensor`: The embedding. `list`: List of all hidden states if `output_hidden_states` is set to
            `True`.
        """
        all_hidden_states = []
        embedding = hidden_state
        for mod in self.mixers:
            embedding = mod(embedding)
            if output_hidden_states:
                all_hidden_states.append(embedding)
        if output_hidden_states:
            return embedding, all_hidden_states
        else:
            return embedding, None
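

# Example (illustrative usage sketch): with output_hidden_states=True the block
# also returns the embedding after every mixer layer:
#
#     block = PatchTSMixerBlock(config=cfg)
#     emb, states = block(torch.randn(2, 3, 8, cfg.d_model), output_hidden_states=True)
#     # len(states) == cfg.num_layers, and states[-1] is the final embedding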


class PatchTSMixerForPredictionHead(nn.Module):
    """Prediction Head for Forecasting

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    """

    def __init__(self, config: PatchTSMixerConfig, distribution_output=None):
        super().__init__()
        self.prediction_channel_indices = config.prediction_channel_indices
        if self.prediction_channel_indices is not None:
            self.prediction_channel_indices.sort()
        self.dropout_layer = nn.Dropout(config.head_dropout)
        if distribution_output is None:
            self.base_forecast_block = nn.Linear((config.num_patches * config.d_model), config.prediction_length)
        else:
            self.base_forecast_block = distribution_output.get_parameter_projection(
                config.num_patches * config.d_model
            )
        self.flatten = nn.Flatten(start_dim=-2)

    def forward(self, hidden_features):
        """
        Args:
            hidden_features (`torch.Tensor` of shape `(batch_size, num_patch, d_model)` in `flatten` mode
                or `(batch_size, n_vars, num_patch, d_model)` in `common_channel`/`mix_channel` mode.): Input hidden
                features.
        Returns:
            `torch.Tensor` of shape `(batch_size, prediction_length, nvars)`.
        """
        hidden_features = self.flatten(hidden_features)  # [batch_size x n_vars x num_patch * d_model]
        hidden_features = self.dropout_layer(hidden_features)  # [batch_size x n_vars x num_patch * d_model]
        forecast = self.base_forecast_block(hidden_features)  # [batch_size x n_vars x prediction_length]
        if isinstance(forecast, tuple):
            forecast = tuple(z.transpose(-1, -2) for z in forecast)
        else:
            forecast = forecast.transpose(-1, -2)  # [batch_size x prediction_length x n_vars]
        if self.prediction_channel_indices is not None:
            if isinstance(forecast, tuple):
                forecast = tuple(z[..., self.prediction_channel_indices] for z in forecast)
            else:
                forecast = forecast[..., self.prediction_channel_indices]  # [batch_size x prediction_length x n_vars]
        return forecast
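

# Shape trace (illustrative; assumes num_patches=8, d_model=16,
# prediction_length=24, and no distribution head):
#
#     [bs, n_vars, 8, 16] --flatten---> [bs, n_vars, 128]
#                         --linear----> [bs, n_vars, 24]
#                         --transpose-> [bs, 24, n_vars]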


class PatchTSMixerLinearHead(nn.Module):
    """Linear head for Classification and Regression.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    """

    def __init__(self, config: PatchTSMixerConfig, distribution_output=None):
        super().__init__()
        self.head_aggregation = config.head_aggregation
        self.output_range = config.output_range
        if config.head_aggregation is None:
            mul_factor = config.num_patches
        else:
            mul_factor = 1
        self.distribution_output = distribution_output
        if distribution_output is None:
            self.projection = nn.Linear(
                config.d_model * config.num_input_channels * mul_factor,
                config.num_targets,
            )
        else:
            self.projection = distribution_output.get_parameter_projection(
                config.d_model * config.num_input_channels * mul_factor
            )
        if config.head_aggregation is None:
            self.flatten = nn.Flatten(start_dim=-3)
        else:
            self.flatten = nn.Flatten(start_dim=-2)
        self.dropout = nn.Dropout(config.head_dropout)

    def forward(self, hidden_features):
        """
        Args:
            hidden_features (`torch.Tensor` of shape `(batch_size, num_patch, d_model)` in `flatten` mode
                or `(batch_size, n_vars, num_patch, d_model)` in `common_channel`/`mix_channel` mode.): Input hidden
                features.
        Returns:
            `torch.Tensor` of shape `(batch_size, num_targets)`.
        """
        # batch_size x d_model x num_patch or batch_size x n_vars x d_model x num_patch
        hidden_features = hidden_features.transpose(-1, -2)
        if self.head_aggregation == "use_last":
            # batch_size x d_model (flatten) or batch_size x n_vars x d_model (common_channel)
            hidden_features = hidden_features[..., -1]
        elif self.head_aggregation == "max_pool":
            # batch_size x n_vars x d_model or batch_size x d_model
            hidden_features = hidden_features.max(dim=-1).values
        elif self.head_aggregation == "avg_pool":
            # batch_size x n_vars x d_model or batch_size x d_model
            hidden_features = hidden_features.mean(dim=-1)
        if self.flatten:
            hidden_features = self.flatten(hidden_features)
        hidden_features = self.dropout(hidden_features)
        hidden_features = self.projection(hidden_features)  # batch_size x num_targets
        if (self.distribution_output is None) and (self.output_range is not None):
            hidden_features = (
                torch.sigmoid(hidden_features) * (self.output_range[1] - self.output_range[0]) + self.output_range[0]
            )
        return hidden_features


class PatchTSMixerPreTrainedModel(PreTrainedModel):
    # Weight initialization
    config_class = PatchTSMixerConfig
    base_model_prefix = "model"
    main_input_name = "past_values"
    supports_gradient_checkpointing = False

    def _init_weights(self, module):
        """Initialize weights"""
        if isinstance(module, PatchTSMixerPositionalEncoding):
            # initialize positional encoding
            if self.config.positional_encoding_type == "random":
                nn.init.normal_(module.position_enc, mean=0.0, std=0.1)
        elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, PatchTSMixerBatchNorm):
            module.batchnorm.bias.data.zero_()
            module.batchnorm.weight.data.fill_(1.0)
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.init_std)
            if module.bias is not None:
                module.bias.data.zero_()


class PatchTSMixerPretrainHead(nn.Module):
    """Pretraining head.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    """

    def __init__(self, config: PatchTSMixerConfig):
        super().__init__()
        self.dropout_layer = nn.Dropout(config.head_dropout)
        self.base_pt_block = nn.Linear(config.d_model, config.patch_length)

    def forward(self, hidden_features):
        """
        Args:
            hidden_features (`torch.Tensor` of shape `(batch_size, num_patch, d_model)` in `flatten` mode
                or `(batch_size, n_vars, num_patch, d_model)` in `common_channel`/`mix_channel` mode.): Input hidden
                features.
        Returns:
            `torch.Tensor` of shape `(batch_size, n_vars, num_patch, patch_length)`.
        """
        hidden_features = self.dropout_layer(hidden_features)
        forecast = self.base_pt_block(hidden_features)  # [batch_size x n_vars x num_patch x patch_length]
        return forecast


# Copied from transformers.models.patchtst.modeling_patchtst.random_masking
def random_masking(
    inputs: torch.Tensor,
    mask_ratio: float,
    unmasked_channel_indices: list = None,
    channel_consistent_masking: bool = False,
    mask_value: int = 0,
):
    """random_masking: Mask the input considering the control variables.

    Args:
        inputs (`torch.Tensor` of shape `(batch_size, num_channels, sequence_length, num_features)`):
            The input tensor to mask.
        mask_ratio (`float`):
            Masking ratio applied to mask the input data during random pretraining. It is a number between 0 and 1.
        unmasked_channel_indices (list, *optional*):
            Indices of channels that will not be masked.
        channel_consistent_masking (bool, *optional*, defaults to `False`):
            When true, masking will be the same across all channels of a timeseries. Otherwise, masking positions
            will vary across channels.
        mask_value (int, *optional*, defaults to 0):
            Define the value of masked patches for pretraining.
    Returns:
        `tuple(torch.Tensor)`: the masked input (same shape as the input tensor) and a mask tensor of shape
        `[bs x num_channels x num_patches]`.
    """
    if mask_ratio < 0 or mask_ratio >= 1:
        raise ValueError(f"Mask ratio {mask_ratio} has to be between 0 and 1.")

    batch_size, num_channels, sequence_length, num_features = inputs.shape
    device = inputs.device

    len_keep = int(sequence_length * (1 - mask_ratio))

    if channel_consistent_masking:
        noise = torch.rand(batch_size, 1, sequence_length, device=device)  # noise in [0, 1], bs x 1 x L
        noise = noise.repeat(1, num_channels, 1)  # bs x num_channels x time
    else:
        # noise in [0, 1], bs x num_channels x L
        noise = torch.rand(batch_size, num_channels, sequence_length, device=device)

    # mask: [bs x num_channels x num_patch]
    mask = torch.ones(batch_size, num_channels, sequence_length, device=device)
    mask[:, :, :len_keep] = 0

    # sort noise for each sample
    ids_shuffle = torch.argsort(noise, dim=-1)  # ascend: small is keep, large is remove
    ids_restore = torch.argsort(ids_shuffle, dim=-1)  # ids_restore: [bs x num_channels x L]

    mask = torch.gather(mask, dim=-1, index=ids_restore)
    mask = mask.unsqueeze(-1).repeat(1, 1, 1, num_features)  # mask: [bs x num_channels x num_patches x patch_length]
    if unmasked_channel_indices is not None:
        mask[:, unmasked_channel_indices, :, :] = 0

    inputs_mask = inputs.masked_fill(mask.bool(), mask_value)
    return inputs_mask, mask[..., 0]
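

# Example (illustrative usage sketch; shapes are hypothetical):
#
#     patches = torch.randn(2, 3, 10, 4)  # [bs x num_channels x num_patches x patch_length]
#     masked, mask = random_masking(patches, mask_ratio=0.4)
#     # mask: [2 x 3 x 10]; roughly 40% of patch positions per series are ones (masked)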


# Copied from transformers.models.patchtst.modeling_patchtst.forecast_masking
def forecast_masking(
    inputs: torch.Tensor,
    num_forecast_mask_patches: Union[list, int],
    unmasked_channel_indices: list = None,
    mask_value: int = 0,
):
    """Forecast masking that masks the last K patches, where K is taken from num_forecast_mask_patches.
    If num_forecast_mask_patches is a list, samples in the batch will be randomly masked by the numbers defined in
    the list.

    Parameters:
        inputs (`torch.Tensor`):
            Input of shape `(bs, num_channels, num_patch, patch_length)`
        num_forecast_mask_patches (`list` or `int`):
            Number of patches to be masked at the end of each batch sample, e.g. 4 or [3, 5].
        unmasked_channel_indices (`list`, *optional*):
            Indices of channels that are not masked.
        mask_value (`int`, *optional*, defaults to 0):
            Values in the masked patches will be filled by `mask_value`.
    Returns:
        `tuple(torch.Tensor)`: the masked input (same shape as the input tensor) and a mask tensor of shape
        `(bs, num_channels, num_patch)`.
    """
    if isinstance(num_forecast_mask_patches, int):
        num_forecast_mask_patches = [num_forecast_mask_patches]
    forecast_mask_ratios = [1 for _ in num_forecast_mask_patches]

    batch_size, num_channels, sequence_length, num_features = inputs.shape
    mask = torch.zeros(batch_size, num_channels, sequence_length, device=inputs.device)

    t_list = []
    total_length = 0
    total_ratio = sum(forecast_mask_ratios)

    for patch_length, ratio in zip(num_forecast_mask_patches, forecast_mask_ratios):
        if patch_length <= 0 or patch_length >= sequence_length:
            raise ValueError(
                f"num_forecast_mask_patches {patch_length} should be greater than 0 and less than total patches."
            )
        temp_len = int(batch_size * ratio / total_ratio)
        t_list.append([patch_length, ratio, temp_len])
        total_length += temp_len

    t_list = sorted(t_list, key=lambda x: x[2])

    if total_length < batch_size:
        t_list[0][2] = t_list[0][2] + (batch_size - total_length)
    elif total_length > batch_size:
        t_list[-1][2] = t_list[-1][2] + (total_length - batch_size)

    batch1 = 0
    for patch_len, _, temp_len in t_list:
        batch2 = batch1 + temp_len
        mask[batch1:batch2, :, -patch_len:] = 1
        batch1 = batch2

    perm = torch.randperm(mask.shape[0])
    mask = mask[perm]

    mask = mask.unsqueeze(-1).repeat(1, 1, 1, num_features)  # mask: [bs x num_channels x num_patch x patch_len]
    if unmasked_channel_indices is not None:
        mask[:, unmasked_channel_indices, :, :] = 0

    inputs_mask = inputs.masked_fill(mask.bool(), mask_value)
    return inputs_mask, mask[..., 0]
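

# Example (illustrative usage sketch): mask the last 3 patches of every sample:
#
#     masked, mask = forecast_masking(patches, num_forecast_mask_patches=3)
#     # mask[..., -3:] is all ones; earlier patches are left untouched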


# Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTPatchify with PatchTST->PatchTSMixer
class PatchTSMixerPatchify(nn.Module):
    """
    A class to patchify the time series sequence into different patches

    Returns:
        `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`
    """

    def __init__(self, config: PatchTSMixerConfig):
        super().__init__()
        self.sequence_length = config.context_length
        self.patch_length = config.patch_length
        self.patch_stride = config.patch_stride

        if self.sequence_length <= self.patch_length:
            raise ValueError(
                f"Sequence length ({self.sequence_length}) has to be greater than the patch length ({self.patch_length})"
            )

        # get the number of patches
        self.num_patches = (max(self.sequence_length, self.patch_length) - self.patch_length) // self.patch_stride + 1
        new_sequence_length = self.patch_length + self.patch_stride * (self.num_patches - 1)
        self.sequence_start = self.sequence_length - new_sequence_length

    def forward(self, past_values: torch.Tensor):
        """
        Parameters:
            past_values (`torch.Tensor` of shape `(batch_size, sequence_length, num_channels)`, *required*):
                Input for patchification
        Returns:
            `torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`
        """
        sequence_length = past_values.shape[-2]
        if sequence_length != self.sequence_length:
            raise ValueError(
                f"Input sequence length ({sequence_length}) doesn't match model configuration ({self.sequence_length})."
            )
        # output: [bs x new_sequence_length x num_channels]
        output = past_values[:, self.sequence_start :, :]
        # output: [bs x num_patches x num_input_channels x patch_length]
        output = output.unfold(dimension=-2, size=self.patch_length, step=self.patch_stride)
        # output: [bs x num_input_channels x num_patches x patch_length]
        output = output.transpose(-2, -3).contiguous()
        return output
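

# Worked example (illustrative; config values are hypothetical): with
# context_length=16, patch_length=2, patch_stride=2 the module computes
# num_patches = (16 - 2) // 2 + 1 = 8, and forward maps
# [bs, 16, n_vars] -> [bs, n_vars, 8, 2] via unfold followed by a transpose.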


# Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTMasking with PatchTST->PatchTSMixer
class PatchTSMixerMasking(nn.Module):
    """
    Class to perform random or forecast masking.

    Parameters:
        config (`PatchTSMixerConfig`): model config
    Returns:
        x_mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`)
            Masked patched input
        mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`)
            Bool tensor indicating True on masked points
    """

    def __init__(self, config: PatchTSMixerConfig):
        super().__init__()
        self.random_mask_ratio = config.random_mask_ratio
        self.channel_consistent_masking = config.channel_consistent_masking
        self.mask_type = config.mask_type
        self.num_forecast_mask_patches = config.num_forecast_mask_patches
        self.unmasked_channel_indices = config.unmasked_channel_indices
        self.mask_value = config.mask_value
        if self.unmasked_channel_indices is not None:
            self.unmasked_channel_indices = sorted(self.unmasked_channel_indices)

    def forward(self, patch_input: torch.Tensor):
        """
        Parameters:
            patch_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`, *required*):
                Patch input
        Return:
            masked_input (`torch.Tensor` of shape `(batch_size, num_channels, num_patches, patch_length)`)
                Masked patched input
            mask (`torch.Tensor` of shape `(batch_size, num_channels, num_patches)`)
                Bool tensor indicating True on masked points
        """
        if self.mask_type == "random":
            masked_input, mask = random_masking(
                inputs=patch_input,
                mask_ratio=self.random_mask_ratio,
                unmasked_channel_indices=self.unmasked_channel_indices,
                channel_consistent_masking=self.channel_consistent_masking,
                mask_value=self.mask_value,
            )
        elif self.mask_type == "forecast":
            masked_input, mask = forecast_masking(
                inputs=patch_input,
                num_forecast_mask_patches=self.num_forecast_mask_patches,
                unmasked_channel_indices=self.unmasked_channel_indices,
                mask_value=self.mask_value,
            )
        else:
            raise ValueError(f"Invalid mask type {self.mask_type}.")

        # mask: [bs x num_input_channels x num_patch]
        mask = mask.bool()
        return masked_input, mask


# Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTStdScaler with PatchTST->PatchTSMixer
class PatchTSMixerStdScaler(nn.Module):
    """
    Standardizes features by computing the mean and standard deviation along the first dimension, then normalizes
    the data by subtracting the mean and dividing by the standard deviation.
    """

    def __init__(self, config: PatchTSMixerConfig):
        super().__init__()
        self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
        self.keepdim = config.keepdim if hasattr(config, "keepdim") else True
        self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-5

    def forward(
        self, data: torch.Tensor, observed_indicator: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Parameters:
            data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                input for Batch norm calculation
            observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Calculating the scale on the observed indicator.
        Returns:
            tuple of `torch.Tensor` of shapes
            (`(batch_size, sequence_length, num_input_channels)`, `(batch_size, 1, num_input_channels)`,
            `(batch_size, 1, num_input_channels)`)
        """
        denominator = observed_indicator.sum(self.dim, keepdim=self.keepdim)
        denominator = denominator.clamp_min(1.0)
        loc = (data * observed_indicator).sum(self.dim, keepdim=self.keepdim) / denominator
        variance = (((data - loc) * observed_indicator) ** 2).sum(self.dim, keepdim=self.keepdim) / denominator
        scale = torch.sqrt(variance + self.minimum_scale)
        return (data - loc) / scale, loc, scale
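

# Example (illustrative sketch; shapes are hypothetical): with the default dim=1
# and keepdim=True, the statistics are computed per channel over time:
#
#     data = torch.randn(2, 16, 3)  # [bs x sequence_length x num_input_channels]
#     scaled, loc, scale = PatchTSMixerStdScaler(cfg)(data, torch.ones_like(data))
#     # loc, scale: [2 x 1 x 3]; scaled has ~zero mean and ~unit std per channel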


# Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTMeanScaler with PatchTST->PatchTSMixer
class PatchTSMixerMeanScaler(nn.Module):
    """
    Computes a scaling factor as the weighted average absolute value along the first dimension, and scales the data
    accordingly.
    """

    def __init__(self, config: PatchTSMixerConfig):
        super().__init__()
        self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
        self.keepdim = config.keepdim if hasattr(config, "keepdim") else True
        self.minimum_scale = config.minimum_scale if hasattr(config, "minimum_scale") else 1e-10
        self.default_scale = config.default_scale if hasattr(config, "default_scale") else None

    def forward(
        self, data: torch.Tensor, observed_indicator: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Parameters:
            data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                input for Batch norm calculation
            observed_indicator (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Calculating the scale on the observed indicator.
        Returns:
            tuple of `torch.Tensor` of shapes
            (`(batch_size, sequence_length, num_input_channels)`, `(batch_size, 1, num_input_channels)`,
            `(batch_size, 1, num_input_channels)`)
        """
        ts_sum = (data * observed_indicator).abs().sum(self.dim, keepdim=True)
        num_observed = observed_indicator.sum(self.dim, keepdim=True)

        scale = ts_sum / torch.clamp(num_observed, min=1)

        # If `default_scale` is provided, we use it, otherwise we use the scale
        # of the batch.
        if self.default_scale is None:
            batch_sum = ts_sum.sum(dim=0)
            batch_observations = torch.clamp(num_observed.sum(0), min=1)
            default_scale = torch.squeeze(batch_sum / batch_observations)
        else:
            default_scale = self.default_scale * torch.ones_like(scale)

        # apply default scale where there are no observations
        scale = torch.where(num_observed > 0, scale, default_scale)

        # ensure the scale is at least `self.minimum_scale`
        scale = torch.clamp(scale, min=self.minimum_scale)
        scaled_data = data / scale

        if not self.keepdim:
            scale = scale.squeeze(dim=self.dim)

        return scaled_data, torch.zeros_like(scale), scale


# Copied from transformers.models.patchtst.modeling_patchtst.PatchTSTNOPScaler with PatchTST->PatchTSMixer
class PatchTSMixerNOPScaler(nn.Module):
    """
    Assigns a scaling factor equal to 1 along the first dimension, and therefore applies no scaling to the input data.
    """

    def __init__(self, config: PatchTSMixerConfig):
        super().__init__()
        self.dim = config.scaling_dim if hasattr(config, "scaling_dim") else 1
        self.keepdim = config.keepdim if hasattr(config, "keepdim") else True

    def forward(
        self, data: torch.Tensor, observed_indicator: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Parameters:
            data (`torch.Tensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                input for Batch norm calculation
        Returns:
            tuple of `torch.Tensor` of shapes
            (`(batch_size, sequence_length, num_input_channels)`, `(batch_size, 1, num_input_channels)`,
            `(batch_size, 1, num_input_channels)`)
        """
        scale = torch.ones_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim)
        loc = torch.zeros_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim)
        return data, loc, scale


@dataclass
class PatchTSMixerEncoderOutput(ModelOutput):
    """
    Base class for `PatchTSMixerEncoderOutput`, with potential hidden states.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, d_model)`):
            Hidden-state at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Hidden-states of the model at the output of each layer.
    """

    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None


class PatchTSMixerEncoder(PatchTSMixerPreTrainedModel):
    """
    Encoder for PatchTSMixer which inputs patched time-series and outputs patched embeddings.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.
    """

    def __init__(self, config: PatchTSMixerConfig):
        super().__init__(config)
        self.use_return_dict = config.use_return_dict
        self.patcher = nn.Linear(config.patch_length, config.d_model)
        if config.use_positional_encoding:
            self.positional_encoder = PatchTSMixerPositionalEncoding(config=config)
        else:
            self.positional_encoder = None
        self.mlp_mixer_encoder = PatchTSMixerBlock(config=config)

        # Initialize weights and apply final processing
        if config.post_init:
            self.post_init()

    @replace_return_docstrings(output_type=PatchTSMixerEncoderOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        past_values: torch.Tensor,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, PatchTSMixerEncoderOutput]:
        r"""
        Args:
            past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
                Context values of the time series. For a pretraining task, this denotes the input time series to
                predict the masked portion. For a forecasting task, this denotes the history/past time series values.
                Similarly, for classification or regression tasks, it denotes the appropriate context values of the
                time series.

                For univariate time series, the `num_input_channels` dimension should be 1. For multivariate time
                series, it is greater than 1.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        Returns:
            `torch.FloatTensor` of shape `(batch_size, n_vars, num_patches, d_model)`
        """
        return_dict = return_dict if return_dict is not None else self.use_return_dict

        # flatten: [bs x num_patch x d_model]. common_channel/mix_channel: [bs x n_vars x num_patch x d_model]
        patches = self.patcher(past_values)

        # add positional encoder
        if self.positional_encoder is not None:
            patches = self.positional_encoder(patches)

        last_hidden_state, hidden_states = self.mlp_mixer_encoder(patches, output_hidden_states=output_hidden_states)

        if not return_dict:
            return tuple(
                v
                for v in [
                    last_hidden_state,
                    hidden_states,
                ]
            )

        return PatchTSMixerEncoderOutput(last_hidden_state=last_hidden_state, hidden_states=hidden_states)


@dataclass
class PatchTSMixerModelOutput(ModelOutput):
    """
    Base class for model's outputs, with potential hidden states.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, d_model)`):
            Hidden-state at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Hidden-states of the model at the output of each layer.
        patch_input (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches, patch_length)`):
            Patched input data to the model.
        mask (`torch.FloatTensor` of shape `(batch_size, num_channels, num_patches)`, *optional*):
            Bool Tensor indicating True in masked patches and False otherwise.
        loc (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*):
            Gives the mean of the context window per channel. Used for RevIN denormalization outside the model, if
            RevIN is enabled.
        scale (`torch.FloatTensor` of shape `(batch_size, 1, num_channels)`, *optional*):
            Gives the std dev of the context window per channel. Used for RevIN denormalization outside the model,
            if RevIN is enabled.
    """

    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    patch_input: torch.FloatTensor = None
    mask: Optional[torch.FloatTensor] = None
    loc: Optional[torch.FloatTensor] = None
    scale: Optional[torch.FloatTensor] = None

@add_start_docstrings(
    "The PatchTSMixer Model for time-series forecasting.",
    PATCHTSMIXER_START_DOCSTRING,
)
class PatchTSMixerModel(PatchTSMixerPreTrainedModel):
    def __init__(self, config: PatchTSMixerConfig, mask_input: bool = False):
        super().__init__(config)

        self.use_return_dict = config.use_return_dict
        self.encoder = PatchTSMixerEncoder(config)
        self.patching = PatchTSMixerPatchify(config)

        if mask_input is True:
            self.masking = PatchTSMixerMasking(config)
        else:
            self.masking = None

        if config.scaling == "mean":
            self.scaler = PatchTSMixerMeanScaler(config)
        elif config.scaling == "std" or config.scaling is True:
            self.scaler = PatchTSMixerStdScaler(config)
        else:
            self.scaler = PatchTSMixerNOPScaler(config)

        # Initialize weights and apply final processing
        if config.post_init:
            self.post_init()

    @add_start_docstrings_to_model_forward(PATCHTSMIXER_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=PatchTSMixerModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        past_values: torch.Tensor,
        observed_mask: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = None,
    ) -> PatchTSMixerModelOutput:
        r"""
        observed_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
            Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
            in `[0, 1]`:
                - 1 for values that are **observed**,
                - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).

        Returns:
        """
        return_dict = return_dict if return_dict is not None else self.use_return_dict

        mask = None
        if observed_mask is None:
            observed_mask = torch.ones_like(past_values)
        scaled_past_values, loc, scale = self.scaler(past_values, observed_mask)

        patched_x = self.patching(scaled_past_values)  # [batch_size x num_input_channels x num_patch x patch_length]

        enc_input = patched_x
        if self.masking is not None:
            enc_input, mask = self.masking(patched_x)
            # enc_input: [batch_size x num_input_channels x num_patch x patch_length]
            # mask: [batch_size x num_input_channels x num_patch]

        encoder_output = self.encoder(
            enc_input,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if isinstance(encoder_output, tuple):
            encoder_output = PatchTSMixerEncoderOutput(*encoder_output)

        if not return_dict:
            return tuple(
                v
                for v in [
                    encoder_output.last_hidden_state,
                    encoder_output.hidden_states,
                    patched_x,
                    mask,
                    loc,
                    scale,
                ]
            )

        return PatchTSMixerModelOutput(
            last_hidden_state=encoder_output.last_hidden_state,
            hidden_states=encoder_output.hidden_states,
            patch_input=patched_x,
            mask=mask,
            loc=loc,
            scale=scale,
        )
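
# Illustrative sketch (not part of the original modeling code): a minimal forward pass
# through the backbone with random inputs. The config values below are assumptions
# chosen for demonstration; any valid `PatchTSMixerConfig` works the same way.
def _demo_backbone_forward():
    import torch

    config = PatchTSMixerConfig(
        context_length=32,  # length of the input window
        patch_length=8,
        patch_stride=8,  # non-overlapping patches -> num_patches = (32 - 8) // 8 + 1 = 4
        num_input_channels=3,
        d_model=16,
    )
    model = PatchTSMixerModel(config)
    past_values = torch.randn(2, config.context_length, config.num_input_channels)
    output = model(past_values)
    # last_hidden_state: [batch_size x num_input_channels x num_patches x d_model]
    print(output.last_hidden_state.shape)  # torch.Size([2, 3, 4, 16])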

@dataclass
class PatchTSMixerForPreTrainingOutput(ModelOutput):
    """
    Output type of [`PatchTSMixerForPretraining`].

    Args:
        prediction_outputs (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, patch_length)`):
            Prediction output from the pretrain head.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Hidden-states of the model at the output of each layer.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
            Backbone embeddings before passing through the head.
        loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`):
            Total loss.
    """

    loss: Optional[torch.FloatTensor] = None
    prediction_outputs: torch.FloatTensor = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None


class PatchTSMixerForPretraining(PatchTSMixerPreTrainedModel):
    r"""
    `PatchTSMixer` for mask pretraining.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.

    Returns:
        `None`.
    """

    def __init__(self, config: PatchTSMixerConfig):
        super().__init__(config)
        self.model = PatchTSMixerModel(config, mask_input=True)
        self.head = PatchTSMixerPretrainHead(config=config)
        self.masked_loss = config.masked_loss
        self.use_return_dict = config.use_return_dict

        # Initialize weights and apply final processing
        if config.post_init:
            self.post_init()

    @add_start_docstrings_to_model_forward(PATCHTSMIXER_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=PatchTSMixerForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        past_values: torch.Tensor,
        observed_mask: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = False,
        return_loss: bool = True,
        return_dict: Optional[bool] = None,
    ) -> PatchTSMixerForPreTrainingOutput:
        r"""
        observed_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
            Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
            in `[0, 1]`:
                - 1 for values that are **observed**,
                - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
        return_loss (`bool`, *optional*):
            Whether to return the loss in the `forward` call.

        Returns:
        """
        return_dict = return_dict if return_dict is not None else self.use_return_dict

        if self.masked_loss is True:
            loss = torch.nn.MSELoss(reduction="none")
        else:
            loss = torch.nn.MSELoss(reduction="mean")

        # past_values: tensor [batch_size x context_length x num_input_channels]
        model_output = self.model(
            past_values,
            observed_mask=observed_mask,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )  # x.last_hidden_state: [batch_size x nvars x num_patch x d_model]

        if isinstance(model_output, tuple):
            model_output = PatchTSMixerModelOutput(*model_output)

        x_hat = self.head(model_output.last_hidden_state)  # tensor [batch_size x nvars x num_patch x patch_length]

        if return_loss is True:
            loss_val = loss(x_hat, model_output.patch_input)
        else:
            loss_val = None

        # calculate masked_loss
        if self.masked_loss is True and loss_val is not None:
            loss_val = (loss_val.mean(dim=-1) * model_output.mask).sum() / (model_output.mask.sum() + 1e-10)

        if not return_dict:
            return tuple(
                v
                for v in [
                    loss_val,
                    x_hat,
                    model_output.last_hidden_state,
                    model_output.hidden_states,
                ]
            )

        return PatchTSMixerForPreTrainingOutput(
            loss=loss_val,
            prediction_outputs=x_hat,  # tensor [batch_size x nvars x num_patch x patch_length]
            last_hidden_state=model_output.last_hidden_state,  # x: [batch_size x nvars x num_patch x d_model]
            hidden_states=model_output.hidden_states,
        )
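
# Illustrative sketch (not part of the original modeling code): one masked-pretraining
# step on random data. The mask-related config values below are assumptions chosen for
# demonstration.
def _demo_pretraining_step():
    import torch

    config = PatchTSMixerConfig(
        context_length=32,
        patch_length=8,
        patch_stride=8,
        num_input_channels=3,
        random_mask_ratio=0.4,  # fraction of patches hidden from the encoder
    )
    model = PatchTSMixerForPretraining(config)
    past_values = torch.randn(2, config.context_length, config.num_input_channels)
    output = model(past_values)
    # With `config.masked_loss` True (the default), the MSE reconstruction loss is
    # averaged over masked patches only, as computed in `forward` above.
    output.loss.backward()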

@dataclass
class PatchTSMixerForPredictionOutput(ModelOutput):
    """
    Output type of [`PatchTSMixerForPrediction`].

    Args:
        prediction_outputs (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_input_channels)`):
            Prediction output from the forecast head.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
            Backbone embeddings before passing through the head.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`):
            Total loss.
        loc (`torch.FloatTensor`, *optional* of shape `(batch_size, 1, num_input_channels)`):
            Input mean.
        scale (`torch.FloatTensor`, *optional* of shape `(batch_size, 1, num_input_channels)`):
            Input std dev.
    """

    loss: Optional[torch.FloatTensor] = None
    prediction_outputs: torch.FloatTensor = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    loc: torch.FloatTensor = None
    scale: torch.FloatTensor = None


@dataclass
class SamplePatchTSMixerPredictionOutput(ModelOutput):
    """
    Base class for time series model's predictions outputs that contains the sampled values from the chosen
    distribution.

    Args:
        sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length, num_input_channels)`):
            Sampled values from the chosen distribution.
    """

    sequences: torch.FloatTensor = None


@dataclass
class SamplePatchTSMixerRegressionOutput(ModelOutput):
    """
    Base class for time series model's predictions outputs that contains the sampled values from the chosen
    distribution.

    Args:
        sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, num_targets)`):
            Sampled values from the chosen distribution.
    """

    sequences: torch.FloatTensor = None


# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.nll
def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor:
    """
    Computes the negative log likelihood loss from input distribution with respect to target.
    """
    return -input.log_prob(target)
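
# Illustrative sketch: `nll` simply evaluates the negative log-density of `target`
# under `input`, so it works with any `torch.distributions.Distribution`.
def _demo_nll():
    import torch

    distribution = torch.distributions.Normal(loc=torch.zeros(2), scale=torch.ones(2))
    target = torch.tensor([0.0, 1.0])
    print(nll(distribution, target))  # elementwise -log p(target)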

# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.weighted_average
def weighted_average(input_tensor: torch.Tensor, weights: Optional[torch.Tensor] = None, dim=None) -> torch.Tensor:
    """
    Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero,
    meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`.

    Args:
        input_tensor (`torch.FloatTensor`):
            Input tensor, of which the average must be computed.
        weights (`torch.FloatTensor`, *optional*):
            Weights tensor, of the same shape as `input_tensor`.
        dim (`int`, *optional*):
            The dim along which to average `input_tensor`.

    Returns:
        `torch.FloatTensor`: The tensor with values averaged along the specified `dim`.
    """
    if weights is not None:
        weighted_tensor = torch.where(weights != 0, input_tensor * weights, torch.zeros_like(input_tensor))
        sum_weights = torch.clamp(weights.sum(dim=dim) if dim else weights.sum(), min=1.0)
        return (weighted_tensor.sum(dim=dim) if dim else weighted_tensor.sum()) / sum_weights
    else:
        return input_tensor.mean(dim=dim)
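
# Illustrative sketch: entries with zero weight are excluded without producing NaNs,
# which matters when the loss tensor contains NaNs at padded or masked positions.
def _demo_weighted_average():
    import torch

    values = torch.tensor([1.0, float("nan"), 3.0])
    weights = torch.tensor([1.0, 0.0, 1.0])
    # `nan * 0` would be nan, but the zero-weight position is zeroed out first:
    print(weighted_average(values, weights=weights))  # tensor(2.) -> (1 + 3) / 2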

class PatchTSMixerForPrediction(PatchTSMixerPreTrainedModel):
    r"""
    `PatchTSMixer` for forecasting application.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.

    Returns:
        `None`.
    """

    def __init__(self, config: PatchTSMixerConfig):
        super().__init__(config)
        self.loss = config.loss
        self.use_return_dict = config.use_return_dict
        self.prediction_channel_indices = config.prediction_channel_indices
        self.num_parallel_samples = config.num_parallel_samples

        if config.loss == "mse":
            self.distribution_output = None
        else:
            dim = config.prediction_length
            distribution_output_map = {
                "student_t": StudentTOutput,
                "normal": NormalOutput,
                "negative_binomial": NegativeBinomialOutput,
            }
            output_class = distribution_output_map.get(config.distribution_output, None)
            if output_class is not None:
                self.distribution_output = output_class(dim=dim)
            else:
                raise ValueError(f"Unknown distribution output {config.distribution_output}")

        self.model = PatchTSMixerModel(config)
        self.head = PatchTSMixerForPredictionHead(
            config=config,
            distribution_output=self.distribution_output,
        )

        # Initialize weights and apply final processing
        if config.post_init:
            self.post_init()

    @add_start_docstrings_to_model_forward(PATCHTSMIXER_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=PatchTSMixerForPredictionOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        past_values: torch.Tensor,
        observed_mask: Optional[torch.Tensor] = None,
        future_values: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = False,
        return_loss: bool = True,
        return_dict: Optional[bool] = None,
    ) -> PatchTSMixerForPredictionOutput:
        r"""
        observed_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
            Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
            in `[0, 1]`:
                - 1 for values that are **observed**,
                - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
        future_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,
            `(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*): Target
            values of the time series that serve as labels for the model. The `future_values` is what the
            Transformer needs during training to learn to output, given the `past_values`. Note that this is NOT
            required for a pretraining task.

            For a forecasting task, the shape should be `(batch_size, target_len, num_input_channels)`. Even if we
            want to forecast only specific channels by setting the indices in the `prediction_channel_indices`
            parameter, pass the target data with all channels, as channel filtering for both prediction and target
            will be manually applied before the loss computation.
        return_loss (`bool`, *optional*):
            Whether to return the loss in the `forward` call.

        Returns:
        """
        if self.loss == "mse":
            loss = nn.MSELoss(reduction="mean")
        elif self.loss == "nll":
            loss = nll
        else:
            raise ValueError("Invalid loss function: Allowed values: mse and nll")

        return_dict = return_dict if return_dict is not None else self.use_return_dict

        # past_values: tensor [batch_size x context_length x num_input_channels]
        model_output = self.model(
            past_values,
            observed_mask=observed_mask,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )  # model_output: [batch_size x nvars x num_patch x d_model]

        if isinstance(model_output, tuple):
            model_output = PatchTSMixerModelOutput(*model_output)

        # tensor [batch_size x prediction_length x num_input_channels]
        y_hat = self.head(model_output.last_hidden_state)

        loss_val = None
        if self.prediction_channel_indices is not None:
            if self.distribution_output:
                distribution = self.distribution_output.distribution(
                    y_hat,
                    loc=model_output.loc[..., self.prediction_channel_indices],
                    scale=model_output.scale[..., self.prediction_channel_indices],
                )
                if future_values is not None and return_loss is True:
                    loss_val = loss(
                        distribution,
                        future_values[..., self.prediction_channel_indices],
                    )
                    # take average of the loss
                    loss_val = weighted_average(loss_val)
            else:
                y_hat = (
                    y_hat * model_output.scale[..., self.prediction_channel_indices]
                    + model_output.loc[..., self.prediction_channel_indices]
                )
                if future_values is not None and return_loss is True:
                    loss_val = loss(y_hat, future_values[..., self.prediction_channel_indices])
        else:
            if self.distribution_output:
                distribution = self.distribution_output.distribution(
                    y_hat, loc=model_output.loc, scale=model_output.scale
                )
                if future_values is not None and return_loss is True:
                    loss_val = loss(distribution, future_values)
                    loss_val = weighted_average(loss_val)
            else:
                y_hat = y_hat * model_output.scale + model_output.loc
                if future_values is not None and return_loss is True:
                    loss_val = loss(y_hat, future_values)

        if self.prediction_channel_indices is not None:
            loc = model_output.loc[..., self.prediction_channel_indices]
            scale = model_output.scale[..., self.prediction_channel_indices]
        else:
            loc = model_output.loc
            scale = model_output.scale

        if not return_dict:
            return tuple(
                v
                for v in [
                    loss_val,
                    y_hat,
                    model_output.last_hidden_state,
                    model_output.hidden_states,
                    loc,
                    scale,
                ]
            )

        return PatchTSMixerForPredictionOutput(
            loss=loss_val,
            prediction_outputs=y_hat,  # tensor [batch_size x prediction_length x num_input_channels]
            last_hidden_state=model_output.last_hidden_state,  # x: [batch_size x nvars x num_patch x d_model]
            hidden_states=model_output.hidden_states,
            loc=loc,
            scale=scale,
        )

    def generate(
        self,
        past_values: torch.Tensor,
        observed_mask: Optional[torch.Tensor] = None,
    ) -> SamplePatchTSMixerPredictionOutput:
        """
        Generate sequences of sample predictions from a model with a probability distribution head.

        Args:
            past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Past values of the time series that serves as context in order to predict the future.
            observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_input_channels)`, *optional*):
                Boolean mask to indicate which `past_values` were observed and which were missing. Mask values
                selected in `[0, 1]`:
                    - 1 for values that are **observed**,
                    - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).

        Return:
            [`SamplePatchTSMixerPredictionOutput`] where the outputs `sequences` tensor will have shape `(batch_size,
            number of samples, prediction_length, num_input_channels)`.
        """
        # get number of samples
        num_parallel_samples = self.num_parallel_samples

        # get model output
        outputs = self(
            past_values=past_values,
            future_values=None,
            observed_mask=observed_mask,
            output_hidden_states=False,
        )

        # get distribution
        distribution = self.distribution_output.distribution(
            outputs.prediction_outputs, loc=outputs.loc, scale=outputs.scale
        )

        # get samples: list of [batch_size x prediction_length x num_channels]
        samples = [distribution.sample() for _ in range(num_parallel_samples)]

        # stack tensors
        samples = torch.stack(samples, dim=1)  # [batch_size x num_samples x prediction_length x num_channels]
        return SamplePatchTSMixerPredictionOutput(sequences=samples)
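
# Illustrative sketch (not part of the original modeling code): point forecasting with
# the default MSE loss, then probabilistic forecasting via `generate`. The config
# values are assumptions chosen for demonstration.
def _demo_prediction():
    import torch

    config = PatchTSMixerConfig(
        context_length=32,
        prediction_length=8,
        patch_length=8,
        patch_stride=8,
        num_input_channels=3,
    )
    # Point forecasts: with config.loss == "mse" (the default), outputs are
    # denormalized values of shape [batch_size x prediction_length x num_input_channels].
    model = PatchTSMixerForPrediction(config)
    past_values = torch.randn(2, config.context_length, config.num_input_channels)
    future_values = torch.randn(2, config.prediction_length, config.num_input_channels)
    output = model(past_values, future_values=future_values)
    print(output.prediction_outputs.shape)  # torch.Size([2, 8, 3])

    # Probabilistic forecasts: loss="nll" attaches a distribution head, and
    # `generate` draws `config.num_parallel_samples` trajectories from it.
    prob_config = PatchTSMixerConfig(
        context_length=32,
        prediction_length=8,
        patch_length=8,
        patch_stride=8,
        num_input_channels=3,
        loss="nll",
        distribution_output="student_t",
        num_parallel_samples=5,
    )
    prob_model = PatchTSMixerForPrediction(prob_config)
    samples = prob_model.generate(past_values)
    print(samples.sequences.shape)  # torch.Size([2, 5, 8, 3])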

@dataclass
class PatchTSMixerForTimeSeriesClassificationOutput(ModelOutput):
    """
    Output type of [`PatchTSMixerForTimeSeriesClassification`].

    Args:
        prediction_outputs (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
            Prediction output from the classification head.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
            Backbone embeddings before passing through the head.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`):
            Total loss.
    """

    loss: Optional[torch.FloatTensor] = None
    prediction_outputs: torch.FloatTensor = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None


class PatchTSMixerForTimeSeriesClassification(PatchTSMixerPreTrainedModel):
    r"""
    `PatchTSMixer` for classification application.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.

    Returns:
        `None`.
    """

    def __init__(self, config: PatchTSMixerConfig):
        super().__init__(config)
        self.model = PatchTSMixerModel(config)
        self.head = PatchTSMixerLinearHead(
            config=config,
        )
        self.use_return_dict = config.use_return_dict
        if config.scaling in ["std", "mean", True]:
            self.inject_scale = InjectScalerStatistics4D(d_model=config.d_model, num_patches=config.num_patches)
        else:
            self.inject_scale = None

        # Initialize weights and apply final processing
        if config.post_init:
            self.post_init()

    @add_start_docstrings_to_model_forward(PATCHTSMIXER_INPUTS_DOCSTRING)
    @replace_return_docstrings(
        output_type=PatchTSMixerForTimeSeriesClassificationOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        past_values: torch.Tensor,
        target_values: torch.Tensor = None,
        output_hidden_states: Optional[bool] = False,
        return_loss: bool = True,
        return_dict: Optional[bool] = None,
    ) -> PatchTSMixerForTimeSeriesClassificationOutput:
        r"""
        target_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,
            `(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*): Target
            values of the time series that serve as labels for the model. The `target_values` is what the
            Transformer needs during training to learn to output, given the `past_values`. Note that this is NOT
            required for a pretraining task.

            For a forecasting task, the shape should be `(batch_size, target_len, num_input_channels)`. Even if we
            want to forecast only specific channels by setting the indices in the `prediction_channel_indices`
            parameter, pass the target data with all channels, as channel filtering for both prediction and target
            will be manually applied before the loss computation.

            For a classification task, it has a shape of `(batch_size,)`.

            For a regression task, it has a shape of `(batch_size, num_targets)`.
        return_loss (`bool`, *optional*):
            Whether to return the loss in the `forward` call.

        Returns:
        """
        loss = torch.nn.CrossEntropyLoss()

        return_dict = return_dict if return_dict is not None else self.use_return_dict

        model_output = self.model(
            past_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )  # x: [batch_size x nvars x num_patch x d_model]

        if isinstance(model_output, tuple):
            model_output = PatchTSMixerModelOutput(*model_output)

        if self.inject_scale is not None:
            model_output.last_hidden_state = self.inject_scale(
                model_output.last_hidden_state,
                loc=model_output.loc,
                scale=model_output.scale,
            )  # x: [batch_size x nvars x num_patch x d_model]

        y_hat = self.head(model_output.last_hidden_state)  # tensor [batch_size x n_labels]

        if target_values is not None and return_loss is True:
            loss_val = loss(y_hat, target_values)
        else:
            loss_val = None

        if not return_dict:
            return tuple(
                v
                for v in [
                    loss_val,
                    y_hat,
                    model_output.last_hidden_state,
                    model_output.hidden_states,
                ]
            )

        return PatchTSMixerForTimeSeriesClassificationOutput(
            loss=loss_val,
            prediction_outputs=y_hat,  # tensor [batch_size x n_labels]
            last_hidden_state=model_output.last_hidden_state,  # x: [batch_size x nvars x num_patch x d_model]
            hidden_states=model_output.hidden_states,
        )
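
# Illustrative sketch (not part of the original modeling code): classification over
# whole series. Treating `num_targets` as the number of classes is an assumption based
# on the linear head used above; the other config values are demonstration choices.
def _demo_classification():
    import torch

    config = PatchTSMixerConfig(
        context_length=32,
        patch_length=8,
        patch_stride=8,
        num_input_channels=3,
        num_targets=4,  # assumed: output width of the linear head, i.e. class count
    )
    model = PatchTSMixerForTimeSeriesClassification(config)
    past_values = torch.randn(2, config.context_length, config.num_input_channels)
    target_values = torch.tensor([0, 3])  # class indices, shape (batch_size,)
    output = model(past_values, target_values=target_values)
    print(output.prediction_outputs.shape)  # torch.Size([2, 4])
    output.loss.backward()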

@dataclass
class PatchTSMixerForRegressionOutput(ModelOutput):
    """
    Output type of [`PatchTSMixerForRegression`].

    Args:
        regression_outputs (`torch.FloatTensor` of shape `(batch_size, num_targets)`):
            Prediction output from the regression head.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
            Backbone embeddings before passing through the head.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        loss (*optional*, returned when `y` is provided, `torch.FloatTensor` of shape `()`):
            Total loss.
    """

    loss: Optional[torch.FloatTensor] = None
    regression_outputs: torch.FloatTensor = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None


class InjectScalerStatistics4D(nn.Module):
    def __init__(self, d_model: int, num_patches: int, expansion: int = 2):
        super().__init__()

        self.inverse_trans_expansion = nn.Linear(d_model + 2, expansion * d_model)
        self.inverse_trans_compression = nn.Linear(expansion * d_model, d_model)
        self.map_scale_expansion = nn.Linear(2, 2 * expansion)
        self.map_scale_compression = nn.Linear(2 * expansion, 2)
        self.num_patches = num_patches

    def forward(self, inputs: torch.Tensor, loc: torch.Tensor, scale: torch.Tensor):
        """
        Args:
            inputs (`torch.Tensor` of shape `(batch_size, num_input_channels, num_patch, d_model)`)
            loc (`torch.Tensor` of shape `(batch_size, 1, num_input_channels)`)
            scale (`torch.Tensor` of shape `(batch_size, 1, num_input_channels)`)
        Returns:
            `torch.Tensor` of shape `(batch_size, num_input_channels, num_patch, d_model)`
        """
        mean = loc.transpose(-1, -2)  # [batch_size x n_channels x 1]
        mean = mean.unsqueeze(-2)  # [batch_size x n_channels x 1 x 1]
        mean = mean.repeat(1, 1, self.num_patches, 1)  # [batch_size x n_channels x num_patch x 1]

        stdev = scale.transpose(-1, -2)  # [batch_size x n_channels x 1]
        stdev = stdev.unsqueeze(-2)  # [batch_size x n_channels x 1 x 1]
        stdev = stdev.repeat(1, 1, self.num_patches, 1)  # [batch_size x n_channels x num_patch x 1]

        concat_stats = torch.cat([mean, stdev], dim=-1)  # [batch_size x n_channels x num_patch x 2]

        concat_stats = self.map_scale_expansion(concat_stats)  # [batch_size x n_channels x num_patch x (2*expansion)]
        concat_stats = self.map_scale_compression(concat_stats)  # [batch_size x n_channels x num_patch x 2]

        inputs = torch.cat([inputs, concat_stats], dim=-1)  # [batch_size x channels x num_patch x (d_model+2)]
        inputs = self.inverse_trans_expansion(inputs)  # [batch_size x channels x num_patch x (expansion*d_model)]
        inputs = self.inverse_trans_compression(inputs)  # [batch_size x channels x num_patch x d_model]

        return inputs
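
# Illustrative sketch: shape flow through `InjectScalerStatistics4D`. The module mixes
# the per-channel mean/std back into the patch embeddings without changing their shape;
# the dimensions below are arbitrary demonstration values.
def _demo_inject_scale():
    import torch

    batch_size, num_channels, num_patches, d_model = 2, 3, 4, 16
    inject = InjectScalerStatistics4D(d_model=d_model, num_patches=num_patches)
    inputs = torch.randn(batch_size, num_channels, num_patches, d_model)
    loc = torch.randn(batch_size, 1, num_channels)
    scale = torch.rand(batch_size, 1, num_channels)
    print(inject(inputs, loc=loc, scale=scale).shape)  # torch.Size([2, 3, 4, 16])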

class PatchTSMixerForRegression(PatchTSMixerPreTrainedModel):
    r"""
    `PatchTSMixer` for regression application.

    Args:
        config (`PatchTSMixerConfig`):
            Configuration.

    Returns:
        `None`.
    """

    def __init__(self, config: PatchTSMixerConfig):
        super().__init__(config)

        self.model = PatchTSMixerModel(config)
        self.loss = config.loss
        self.distribution_output = config.distribution_output
        self.use_return_dict = config.use_return_dict
        self.num_parallel_samples = config.num_parallel_samples

        if config.loss == "mse":
            self.distribution_output = None
        else:
            distribution_output_map = {
                "student_t": StudentTOutput,
                "normal": NormalOutput,
                "negative_binomial": NegativeBinomialOutput,
            }
            output_class = distribution_output_map.get(config.distribution_output)
            if output_class is not None:
                self.distribution_output = output_class(dim=config.num_targets)
            else:
                raise ValueError(f"Unknown distribution output {config.distribution_output}")

        if config.scaling in ["std", "mean", True]:
            self.inject_scale = InjectScalerStatistics4D(d_model=config.d_model, num_patches=config.num_patches)
        else:
            self.inject_scale = None

        self.head = PatchTSMixerLinearHead(
            config=config,
            distribution_output=self.distribution_output,
        )

        # Initialize weights and apply final processing
        if config.post_init:
            self.post_init()

    @add_start_docstrings_to_model_forward(PATCHTSMIXER_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=PatchTSMixerForRegressionOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        past_values: torch.Tensor,
        target_values: torch.Tensor = None,
        output_hidden_states: Optional[bool] = False,
        return_loss: bool = True,
        return_dict: Optional[bool] = None,
    ) -> PatchTSMixerForRegressionOutput:
        r"""
        target_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,
            `(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*): Target
            values of the time series that serve as labels for the model. The `target_values` is what the
            Transformer needs during training to learn to output, given the `past_values`. Note that this is NOT
            required for a pretraining task.

            For a forecasting task, the shape should be `(batch_size, target_len, num_input_channels)`. Even if we
            want to forecast only specific channels by setting the indices in the `prediction_channel_indices`
            parameter, pass the target data with all channels, as channel filtering for both prediction and target
            will be manually applied before the loss computation.

            For a classification task, it has a shape of `(batch_size,)`.

            For a regression task, it has a shape of `(batch_size, num_targets)`.
        return_loss (`bool`, *optional*):
            Whether to return the loss in the `forward` call.

        Returns:
        """
        if self.loss == "mse":
            loss = nn.MSELoss(reduction="mean")
        elif self.loss == "nll":
            loss = nll
        else:
            raise ValueError("Invalid loss function: Allowed values: mse and nll")

        return_dict = return_dict if return_dict is not None else self.use_return_dict

        model_output = self.model(
            past_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )  # model_output: [batch_size x nvars x num_patch x d_model]

        if isinstance(model_output, tuple):
            model_output = PatchTSMixerModelOutput(*model_output)

        if self.inject_scale is not None:
            model_output.last_hidden_state = self.inject_scale(
                model_output.last_hidden_state,
                loc=model_output.loc,
                scale=model_output.scale,
            )  # x: [batch_size x nvars x num_patch x d_model]

        y_hat = self.head(model_output.last_hidden_state)  # [batch_size x num_targets]

        if target_values is not None and return_loss is True:
            if self.distribution_output:
                # Note: `self.distribution_output` holds an output-class instance here, so it is
                # checked with `isinstance` rather than compared against the config string.
                if isinstance(self.distribution_output, NegativeBinomialOutput) and torch.any(target_values < 0):
                    raise Exception("target_values cannot be negative for negative_binomial distribution.")
                distribution = self.distribution_output.distribution(y_hat)
                # y_hat should be a 2-tuple, each with dimension [bs, num_targets]
                y_hat = tuple([item.view(-1, self.config.num_targets) for item in y_hat])
                loss_val = loss(distribution, target_values)
                # take average of the loss
                loss_val = weighted_average(loss_val)
            else:
                loss_val = loss(y_hat, target_values)
        else:
            loss_val = None

        if not return_dict:
            return tuple(
                v
                for v in [
                    loss_val,
                    y_hat,
                    model_output.last_hidden_state,
                    model_output.hidden_states,
                ]
            )

        return PatchTSMixerForRegressionOutput(
            loss=loss_val,
            regression_outputs=y_hat,  # tensor [batch_size x num_targets]
            last_hidden_state=model_output.last_hidden_state,  # [batch_size x nvars x num_patch x d_model]
            hidden_states=model_output.hidden_states,
        )

    def generate(
        self,
        past_values: torch.Tensor,
    ) -> SamplePatchTSMixerRegressionOutput:
        """
        Generate sequences of sample predictions from a model with a probability distribution head.

        Args:
            past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
                Past values of the time series that serves as context in order to predict the target values.

        Return:
            [`SamplePatchTSMixerRegressionOutput`] where the outputs `sequences` tensor will have shape `(batch_size,
            number of samples, num_targets)`.
        """
        # get number of samples
        num_parallel_samples = self.num_parallel_samples

        # get model output
        outputs = self(
            past_values=past_values,
            target_values=None,
            output_hidden_states=False,
        )

        # get distribution
        distribution = self.distribution_output.distribution(outputs.regression_outputs)

        # get samples
        samples = [
            distribution.sample() for _ in range(num_parallel_samples)
        ]  # samples: list of [batch_size x num_targets]

        # stack tensors
        # [batch_size x num_samples x num_targets]
        samples = torch.stack(samples, dim=1).view(-1, num_parallel_samples, self.config.num_targets)
        return SamplePatchTSMixerRegressionOutput(sequences=samples)
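
# Illustrative sketch (not part of the original modeling code): probabilistic
# regression with `generate`, which requires a distribution head (`loss="nll"`).
# The config values are assumptions chosen for demonstration.
def _demo_regression_generate():
    import torch

    config = PatchTSMixerConfig(
        context_length=32,
        patch_length=8,
        patch_stride=8,
        num_input_channels=3,
        num_targets=2,
        loss="nll",
        distribution_output="normal",
        num_parallel_samples=5,
    )
    model = PatchTSMixerForRegression(config)
    past_values = torch.randn(2, config.context_length, config.num_input_channels)
    samples = model.generate(past_values)
    # sequences: [batch_size x num_parallel_samples x num_targets]
    print(samples.sequences.shape)  # torch.Size([2, 5, 2])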