# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Time series distributional output classes and utilities.
"""
from typing import Callable, Dict, Optional, Tuple

import torch
from torch import nn
from torch.distributions import (
    AffineTransform,
    Distribution,
    Independent,
    NegativeBinomial,
    Normal,
    StudentT,
    TransformedDistribution,
)


class AffineTransformed(TransformedDistribution):
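    """
    Represents the distribution of an affinely transformed random variable. This is the distribution of `Y = loc +
    scale * X`, where `X` is a random variable distributed according to `base_distribution`.

    Example (illustrative; the values follow from the affine identities implemented below):

    >>> d = AffineTransformed(Normal(torch.zeros(1), torch.ones(1)), loc=2.0, scale=3.0)
    >>> d.mean, d.variance  # 0 * 3 + 2 = 2.0 and 1 * 3**2 = 9.0
    (tensor([2.]), tensor([9.]))
    """
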
    def __init__(self, base_distribution: Distribution, loc=None, scale=None, event_dim=0):
        self.scale = 1.0 if scale is None else scale
        self.loc = 0.0 if loc is None else loc

        super().__init__(base_distribution, [AffineTransform(loc=self.loc, scale=self.scale, event_dim=event_dim)])

    @property
    def mean(self):
        """
        Returns the mean of the distribution.
        """
        return self.base_dist.mean * self.scale + self.loc

    @property
    def variance(self):
        """
        Returns the variance of the distribution.
        """
        return self.base_dist.variance * self.scale**2

    @property
    def stddev(self):
        """
        Returns the standard deviation of the distribution.
        """
        return self.variance.sqrt()


class ParameterProjection(nn.Module):
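    """
    Projects input features to a tuple of distribution parameters: one `nn.Linear` head per entry of `args_dim`,
    followed by `domain_map` to constrain each raw output to its valid domain.
    """
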
    def __init__(
        self, in_features: int, args_dim: Dict[str, int], domain_map: Callable[..., Tuple[torch.Tensor]], **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.args_dim = args_dim
        self.proj = nn.ModuleList([nn.Linear(in_features, dim) for dim in args_dim.values()])
        self.domain_map = domain_map

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
        params_unbounded = [proj(x) for proj in self.proj]
        return self.domain_map(*params_unbounded)


class LambdaLayer(nn.Module):
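    """
    Wraps a plain callable as an `nn.Module`, so that a function such as `domain_map` can be composed with other
    layers (see `DistributionOutput.get_parameter_projection`).
    """
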
    def __init__(self, function):
        super().__init__()
        self.function = function

    def forward(self, x, *args):
        return self.function(x, *args)


class DistributionOutput:
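    """
    Base class for constructing a distribution from network outputs. Subclasses set `distribution_class` and
    `args_dim`, and implement `domain_map` to constrain the raw parameter projections.
    """
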
    distribution_class: type
    in_features: int
    args_dim: Dict[str, int]

    def __init__(self, dim: int = 1) -> None:
        self.dim = dim
        self.args_dim = {k: dim * self.args_dim[k] for k in self.args_dim}

    def _base_distribution(self, distr_args):
        if self.dim == 1:
            return self.distribution_class(*distr_args)
        else:
            return Independent(self.distribution_class(*distr_args), 1)

    def distribution(
        self,
        distr_args,
        loc: Optional[torch.Tensor] = None,
        scale: Optional[torch.Tensor] = None,
    ) -> Distribution:
        distr = self._base_distribution(distr_args)
        if loc is None and scale is None:
            return distr
        else:
            return AffineTransformed(distr, loc=loc, scale=scale, event_dim=self.event_dim)

    @property
    def event_shape(self) -> Tuple:
        r"""
        Shape of each individual event contemplated by the distributions that this object constructs.
        """
        return () if self.dim == 1 else (self.dim,)

    @property
    def event_dim(self) -> int:
        r"""
        Number of event dimensions, i.e., length of the `event_shape` tuple, of the distributions that this object
        constructs.
        """
        return len(self.event_shape)

    @property
    def value_in_support(self) -> float:
        r"""
        A float that will have a valid numeric value when computing the log-loss of the corresponding distribution. By
        default 0.0. This value will be used when padding data series.
        """
        return 0.0

    def get_parameter_projection(self, in_features: int) -> nn.Module:
        r"""
        Return the parameter projection layer that maps the input to the appropriate parameters of the distribution.
        """
        return ParameterProjection(
            in_features=in_features,
            args_dim=self.args_dim,
            domain_map=LambdaLayer(self.domain_map),
        )

    def domain_map(self, *args: torch.Tensor):
        r"""
        Converts arguments to the right shape and domain. The domain depends on the type of distribution, while the
        correct shape is obtained by reshaping the trailing axis in such a way that the returned tensors define a
        distribution of the right event_shape.
        """
        raise NotImplementedError()

    @staticmethod
    def squareplus(x: torch.Tensor) -> torch.Tensor:
        r"""
        Helper to map inputs to the positive orthant by applying the square-plus operation. Reference:
        https://twitter.com/jon_barron/status/1387167648669048833
        """
        return (x + torch.sqrt(torch.square(x) + 4.0)) / 2.0


class StudentTOutput(DistributionOutput):
    """
    Student-T distribution output class.
    """

    args_dim: Dict[str, int] = {"df": 1, "loc": 1, "scale": 1}
    distribution_class: type = StudentT

    @classmethod
    def domain_map(cls, df: torch.Tensor, loc: torch.Tensor, scale: torch.Tensor):
        scale = cls.squareplus(scale).clamp_min(torch.finfo(scale.dtype).eps)
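        # Shift by 2.0 so the degrees of freedom stay above 2, keeping the variance finite.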
        df = 2.0 + cls.squareplus(df)
        return df.squeeze(-1), loc.squeeze(-1), scale.squeeze(-1)


class NormalOutput(DistributionOutput):
    """
    Normal distribution output class.
    """

    args_dim: Dict[str, int] = {"loc": 1, "scale": 1}
    distribution_class: type = Normal

    @classmethod
    def domain_map(cls, loc: torch.Tensor, scale: torch.Tensor):
        scale = cls.squareplus(scale).clamp_min(torch.finfo(scale.dtype).eps)
        return loc.squeeze(-1), scale.squeeze(-1)


class NegativeBinomialOutput(DistributionOutput):
    """
    Negative Binomial distribution output class.
    """

    args_dim: Dict[str, int] = {"total_count": 1, "logits": 1}
    distribution_class: type = NegativeBinomial

    @classmethod
    def domain_map(cls, total_count: torch.Tensor, logits: torch.Tensor):
        total_count = cls.squareplus(total_count)
        return total_count.squeeze(-1), logits.squeeze(-1)

    def _base_distribution(self, distr_args) -> Distribution:
        total_count, logits = distr_args
        if self.dim == 1:
            return self.distribution_class(total_count=total_count, logits=logits)
        else:
            return Independent(self.distribution_class(total_count=total_count, logits=logits), 1)
    # Overrides the parent class method: we cannot scale using the affine
    # transformation, since the negative binomial is supported on the integers
    # and an affine map would move it off that support. Instead, we scale the
    # distribution's parameters.
    def distribution(
        self, distr_args, loc: Optional[torch.Tensor] = None, scale: Optional[torch.Tensor] = None
    ) -> Distribution:
        total_count, logits = distr_args

        if scale is not None:
            # By the scaling property of the Gamma distribution (the negative binomial is a
            # Gamma-Poisson mixture), multiplying the mean by `scale` amounts to adding
            # `log(scale)` to the logits.
            logits += scale.log()

        return self._base_distribution((total_count, logits))
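

# ---------------------------------------------------------------------------
# Illustrative usage: a minimal sketch showing how the pieces above fit
# together. The hidden size, batch shape, and the choice of StudentTOutput are
# arbitrary assumptions made only for this demo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch, seq_len, hidden = 4, 10, 32  # assumed sizes, not prescribed by the module
    features = torch.randn(batch, seq_len, hidden)

    # Pick an output head and build its projection from the model's hidden size.
    output = StudentTOutput(dim=1)
    projection = output.get_parameter_projection(in_features=hidden)

    # Project hidden states to (df, loc, scale), already mapped to valid domains
    # by StudentTOutput.domain_map.
    distr_args = projection(features)

    # Construct the distribution and score some (random) targets with the
    # negative log-likelihood, as a training loss would.
    distr = output.distribution(distr_args)
    targets = torch.randn(batch, seq_len)
    nll = -distr.log_prob(targets).mean()
    print([a.shape for a in distr_args], nll.item())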