# multivariate_normal.py
  1. # mypy: allow-untyped-defs
  2. import math
  3. import torch
  4. from torch.distributions import constraints
  5. from torch.distributions.distribution import Distribution
  6. from torch.distributions.utils import _standard_normal, lazy_property
  7. __all__ = ["MultivariateNormal"]
  8. def _batch_mv(bmat, bvec):
  9. r"""
  10. Performs a batched matrix-vector product, with compatible but different batch shapes.
  11. This function takes as input `bmat`, containing :math:`n \times n` matrices, and
  12. `bvec`, containing length :math:`n` vectors.
  13. Both `bmat` and `bvec` may have any number of leading dimensions, which correspond
  14. to a batch shape. They are not necessarily assumed to have the same batch shape,
  15. just ones which can be broadcasted.
  16. """
  17. return torch.matmul(bmat, bvec.unsqueeze(-1)).squeeze(-1)
def _batch_mahalanobis(bL, bx):
    r"""
    Computes the squared Mahalanobis distance :math:`\mathbf{x}^\top\mathbf{M}^{-1}\mathbf{x}`
    for a factored :math:`\mathbf{M} = \mathbf{L}\mathbf{L}^\top`.

    Accepts batches for both bL and bx. They are not necessarily assumed to have the same batch
    shape, but `bL` one should be able to broadcasted to `bx` one.
    """
    n = bx.size(-1)
    bx_batch_shape = bx.shape[:-1]

    # Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n),
    # we are going to make bx have shape (..., 1, j, i, 1, n) to apply batched tri.solve
    bx_batch_dims = len(bx_batch_shape)
    bL_batch_dims = bL.dim() - 2
    # Leading bx batch dims that have no counterpart in bL's batch shape.
    outer_batch_dims = bx_batch_dims - bL_batch_dims
    old_batch_dims = outer_batch_dims + bL_batch_dims
    new_batch_dims = outer_batch_dims + 2 * bL_batch_dims
    # Reshape bx with the shape (..., 1, i, j, 1, n)
    bx_new_shape = bx.shape[:outer_batch_dims]
    for sL, sx in zip(bL.shape[:-2], bx.shape[outer_batch_dims:-1]):
        # Split each shared batch dim sx into (sx // sL, sL) so that the
        # trailing factor lines up with the matching bL batch dim.
        bx_new_shape += (sx // sL, sL)
    bx_new_shape += (n,)
    bx = bx.reshape(bx_new_shape)
    # Permute bx to make it have shape (..., 1, j, i, 1, n)
    permute_dims = (
        list(range(outer_batch_dims))
        + list(range(outer_batch_dims, new_batch_dims, 2))
        + list(range(outer_batch_dims + 1, new_batch_dims, 2))
        + [new_batch_dims]
    )
    bx = bx.permute(permute_dims)

    # Flatten both operands so a single batched triangular solve handles all
    # (L, x) pairs at once: solving L z = x gives z = L^{-1} x, and
    # ||z||^2 = x^T (L L^T)^{-1} x, the squared Mahalanobis distance.
    flat_L = bL.reshape(-1, n, n)  # shape = b x n x n
    flat_x = bx.reshape(-1, flat_L.size(0), n)  # shape = c x b x n
    flat_x_swap = flat_x.permute(1, 2, 0)  # shape = b x n x c
    M_swap = (
        torch.linalg.solve_triangular(flat_L, flat_x_swap, upper=False).pow(2).sum(-2)
    )  # shape = b x c
    M = M_swap.t()  # shape = c x b

    # Now we revert the above reshape and permute operators.
    permuted_M = M.reshape(bx.shape[:-1])  # shape = (..., 1, j, i, 1)
    permute_inv_dims = list(range(outer_batch_dims))
    for i in range(bL_batch_dims):
        permute_inv_dims += [outer_batch_dims + i, old_batch_dims + i]
    reshaped_M = permuted_M.permute(permute_inv_dims)  # shape = (..., 1, i, j, 1)
    return reshaped_M.reshape(bx_batch_shape)
  62. def _precision_to_scale_tril(P):
  63. # Ref: https://nbviewer.jupyter.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril
  64. Lf = torch.linalg.cholesky(torch.flip(P, (-2, -1)))
  65. L_inv = torch.transpose(torch.flip(Lf, (-2, -1)), -2, -1)
  66. Id = torch.eye(P.shape[-1], dtype=P.dtype, device=P.device)
  67. L = torch.linalg.solve_triangular(L_inv, Id, upper=False)
  68. return L
  69. class MultivariateNormal(Distribution):
  70. r"""
  71. Creates a multivariate normal (also called Gaussian) distribution
  72. parameterized by a mean vector and a covariance matrix.
  73. The multivariate normal distribution can be parameterized either
  74. in terms of a positive definite covariance matrix :math:`\mathbf{\Sigma}`
  75. or a positive definite precision matrix :math:`\mathbf{\Sigma}^{-1}`
  76. or a lower-triangular matrix :math:`\mathbf{L}` with positive-valued
  77. diagonal entries, such that
  78. :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`. This triangular matrix
  79. can be obtained via e.g. Cholesky decomposition of the covariance.
  80. Example:
  81. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
  82. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  83. >>> m = MultivariateNormal(torch.zeros(2), torch.eye(2))
  84. >>> m.sample() # normally distributed with mean=`[0,0]` and covariance_matrix=`I`
  85. tensor([-0.2102, -0.5429])
  86. Args:
  87. loc (Tensor): mean of the distribution
  88. covariance_matrix (Tensor): positive-definite covariance matrix
  89. precision_matrix (Tensor): positive-definite precision matrix
  90. scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal
  91. Note:
  92. Only one of :attr:`covariance_matrix` or :attr:`precision_matrix` or
  93. :attr:`scale_tril` can be specified.
  94. Using :attr:`scale_tril` will be more efficient: all computations internally
  95. are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or
  96. :attr:`precision_matrix` is passed instead, it is only used to compute
  97. the corresponding lower triangular matrices using a Cholesky decomposition.
  98. """
    # Constraints checked against constructor arguments when validation is on.
    arg_constraints = {
        "loc": constraints.real_vector,
        "covariance_matrix": constraints.positive_definite,
        "precision_matrix": constraints.positive_definite,
        "scale_tril": constraints.lower_cholesky,
    }
    # Samples are unconstrained real vectors (the last dim is the event dim).
    support = constraints.real_vector
    # Reparameterized sampling is supported (see rsample), so gradients can
    # flow through samples.
    has_rsample = True
    def __init__(
        self,
        loc,
        covariance_matrix=None,
        precision_matrix=None,
        scale_tril=None,
        validate_args=None,
    ):
        """Initialize from a mean vector and exactly one covariance parameterization.

        Args:
            loc (Tensor): mean; its last dimension is the event dimension.
            covariance_matrix (Tensor, optional): positive-definite covariance.
            precision_matrix (Tensor, optional): positive-definite precision.
            scale_tril (Tensor, optional): lower-triangular factor of covariance.
            validate_args (bool, optional): forwarded to Distribution.

        Raises:
            ValueError: if `loc` is zero-dimensional, if not exactly one of the
                three matrix parameterizations is given, or if the given matrix
                has fewer than two dimensions.
        """
        if loc.dim() < 1:
            raise ValueError("loc must be at least one-dimensional.")
        # Booleans add as ints here: exactly one parameterization must be set.
        if (covariance_matrix is not None) + (scale_tril is not None) + (
            precision_matrix is not None
        ) != 1:
            raise ValueError(
                "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."
            )
        if scale_tril is not None:
            if scale_tril.dim() < 2:
                raise ValueError(
                    "scale_tril matrix must be at least two-dimensional, "
                    "with optional leading batch dimensions"
                )
            # Broadcast the matrix batch dims against loc's batch dims, then
            # expand the stored parameter to the common batch shape.
            batch_shape = torch.broadcast_shapes(scale_tril.shape[:-2], loc.shape[:-1])
            self.scale_tril = scale_tril.expand(batch_shape + (-1, -1))
        elif covariance_matrix is not None:
            if covariance_matrix.dim() < 2:
                raise ValueError(
                    "covariance_matrix must be at least two-dimensional, "
                    "with optional leading batch dimensions"
                )
            batch_shape = torch.broadcast_shapes(
                covariance_matrix.shape[:-2], loc.shape[:-1]
            )
            self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1))
        else:
            if precision_matrix.dim() < 2:
                raise ValueError(
                    "precision_matrix must be at least two-dimensional, "
                    "with optional leading batch dimensions"
                )
            batch_shape = torch.broadcast_shapes(
                precision_matrix.shape[:-2], loc.shape[:-1]
            )
            self.precision_matrix = precision_matrix.expand(batch_shape + (-1, -1))
        self.loc = loc.expand(batch_shape + (-1,))
        event_shape = self.loc.shape[-1:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)
        # All internal math runs on a Cholesky factor of the covariance; derive
        # it once from whichever parameterization was supplied. It is kept in
        # its original (unbroadcasted) batch shape to avoid redundant work.
        if scale_tril is not None:
            self._unbroadcasted_scale_tril = scale_tril
        elif covariance_matrix is not None:
            self._unbroadcasted_scale_tril = torch.linalg.cholesky(covariance_matrix)
        else:  # precision_matrix is not None
            self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix)
  160. def expand(self, batch_shape, _instance=None):
  161. new = self._get_checked_instance(MultivariateNormal, _instance)
  162. batch_shape = torch.Size(batch_shape)
  163. loc_shape = batch_shape + self.event_shape
  164. cov_shape = batch_shape + self.event_shape + self.event_shape
  165. new.loc = self.loc.expand(loc_shape)
  166. new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril
  167. if "covariance_matrix" in self.__dict__:
  168. new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
  169. if "scale_tril" in self.__dict__:
  170. new.scale_tril = self.scale_tril.expand(cov_shape)
  171. if "precision_matrix" in self.__dict__:
  172. new.precision_matrix = self.precision_matrix.expand(cov_shape)
  173. super(MultivariateNormal, new).__init__(
  174. batch_shape, self.event_shape, validate_args=False
  175. )
  176. new._validate_args = self._validate_args
  177. return new
  178. @lazy_property
  179. def scale_tril(self):
  180. return self._unbroadcasted_scale_tril.expand(
  181. self._batch_shape + self._event_shape + self._event_shape
  182. )
  183. @lazy_property
  184. def covariance_matrix(self):
  185. return torch.matmul(
  186. self._unbroadcasted_scale_tril, self._unbroadcasted_scale_tril.mT
  187. ).expand(self._batch_shape + self._event_shape + self._event_shape)
  188. @lazy_property
  189. def precision_matrix(self):
  190. return torch.cholesky_inverse(self._unbroadcasted_scale_tril).expand(
  191. self._batch_shape + self._event_shape + self._event_shape
  192. )
  193. @property
  194. def mean(self):
  195. return self.loc
  196. @property
  197. def mode(self):
  198. return self.loc
  199. @property
  200. def variance(self):
  201. return (
  202. self._unbroadcasted_scale_tril.pow(2)
  203. .sum(-1)
  204. .expand(self._batch_shape + self._event_shape)
  205. )
  206. def rsample(self, sample_shape=torch.Size()):
  207. shape = self._extended_shape(sample_shape)
  208. eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
  209. return self.loc + _batch_mv(self._unbroadcasted_scale_tril, eps)
  210. def log_prob(self, value):
  211. if self._validate_args:
  212. self._validate_sample(value)
  213. diff = value - self.loc
  214. M = _batch_mahalanobis(self._unbroadcasted_scale_tril, diff)
  215. half_log_det = (
  216. self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
  217. )
  218. return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + M) - half_log_det
  219. def entropy(self):
  220. half_log_det = (
  221. self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
  222. )
  223. H = 0.5 * self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + half_log_det
  224. if len(self._batch_shape) == 0:
  225. return H
  226. else:
  227. return H.expand(self._batch_shape)