utils.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. # mypy: allow-untyped-defs
  2. from functools import update_wrapper
  3. from numbers import Number
  4. from typing import Any, Dict
  5. import torch
  6. import torch.nn.functional as F
  7. from torch.overrides import is_tensor_like
# Euler-Mascheroni constant (gamma), used by distribution math in this package.
euler_constant = 0.57721566490153286060  # Euler Mascheroni Constant

# Explicit public API of this utilities module.
__all__ = [
    "broadcast_all",
    "logits_to_probs",
    "clamp_probs",
    "probs_to_logits",
    "lazy_property",
    "tril_matrix_to_vec",
    "vec_to_tril_matrix",
]
  18. def broadcast_all(*values):
  19. r"""
  20. Given a list of values (possibly containing numbers), returns a list where each
  21. value is broadcasted based on the following rules:
  22. - `torch.*Tensor` instances are broadcasted as per :ref:`_broadcasting-semantics`.
  23. - numbers.Number instances (scalars) are upcast to tensors having
  24. the same size and type as the first tensor passed to `values`. If all the
  25. values are scalars, then they are upcasted to scalar Tensors.
  26. Args:
  27. values (list of `numbers.Number`, `torch.*Tensor` or objects implementing __torch_function__)
  28. Raises:
  29. ValueError: if any of the values is not a `numbers.Number` instance,
  30. a `torch.*Tensor` instance, or an instance implementing __torch_function__
  31. """
  32. if not all(is_tensor_like(v) or isinstance(v, Number) for v in values):
  33. raise ValueError(
  34. "Input arguments must all be instances of numbers.Number, "
  35. "torch.Tensor or objects implementing __torch_function__."
  36. )
  37. if not all(is_tensor_like(v) for v in values):
  38. options: Dict[str, Any] = dict(dtype=torch.get_default_dtype())
  39. for value in values:
  40. if isinstance(value, torch.Tensor):
  41. options = dict(dtype=value.dtype, device=value.device)
  42. break
  43. new_values = [
  44. v if is_tensor_like(v) else torch.tensor(v, **options) for v in values
  45. ]
  46. return torch.broadcast_tensors(*new_values)
  47. return torch.broadcast_tensors(*values)
  48. def _standard_normal(shape, dtype, device):
  49. if torch._C._get_tracing_state():
  50. # [JIT WORKAROUND] lack of support for .normal_()
  51. return torch.normal(
  52. torch.zeros(shape, dtype=dtype, device=device),
  53. torch.ones(shape, dtype=dtype, device=device),
  54. )
  55. return torch.empty(shape, dtype=dtype, device=device).normal_()
  56. def _sum_rightmost(value, dim):
  57. r"""
  58. Sum out ``dim`` many rightmost dimensions of a given tensor.
  59. Args:
  60. value (Tensor): A tensor of ``.dim()`` at least ``dim``.
  61. dim (int): The number of rightmost dims to sum out.
  62. """
  63. if dim == 0:
  64. return value
  65. required_shape = value.shape[:-dim] + (-1,)
  66. return value.reshape(required_shape).sum(-1)
  67. def logits_to_probs(logits, is_binary=False):
  68. r"""
  69. Converts a tensor of logits into probabilities. Note that for the
  70. binary case, each value denotes log odds, whereas for the
  71. multi-dimensional case, the values along the last dimension denote
  72. the log probabilities (possibly unnormalized) of the events.
  73. """
  74. if is_binary:
  75. return torch.sigmoid(logits)
  76. return F.softmax(logits, dim=-1)
  77. def clamp_probs(probs):
  78. """Clamps the probabilities to be in the open interval `(0, 1)`.
  79. The probabilities would be clamped between `eps` and `1 - eps`,
  80. and `eps` would be the smallest representable positive number for the input data type.
  81. Args:
  82. probs (Tensor): A tensor of probabilities.
  83. Returns:
  84. Tensor: The clamped probabilities.
  85. Examples:
  86. >>> probs = torch.tensor([0.0, 0.5, 1.0])
  87. >>> clamp_probs(probs)
  88. tensor([1.1921e-07, 5.0000e-01, 1.0000e+00])
  89. >>> probs = torch.tensor([0.0, 0.5, 1.0], dtype=torch.float64)
  90. >>> clamp_probs(probs)
  91. tensor([2.2204e-16, 5.0000e-01, 1.0000e+00], dtype=torch.float64)
  92. """
  93. eps = torch.finfo(probs.dtype).eps
  94. return probs.clamp(min=eps, max=1 - eps)
  95. def probs_to_logits(probs, is_binary=False):
  96. r"""
  97. Converts a tensor of probabilities into logits. For the binary case,
  98. this denotes the probability of occurrence of the event indexed by `1`.
  99. For the multi-dimensional case, the values along the last dimension
  100. denote the probabilities of occurrence of each of the events.
  101. """
  102. ps_clamped = clamp_probs(probs)
  103. if is_binary:
  104. return torch.log(ps_clamped) - torch.log1p(-ps_clamped)
  105. return torch.log(ps_clamped)
  106. class lazy_property:
  107. r"""
  108. Used as a decorator for lazy loading of class attributes. This uses a
  109. non-data descriptor that calls the wrapped method to compute the property on
  110. first call; thereafter replacing the wrapped method into an instance
  111. attribute.
  112. """
  113. def __init__(self, wrapped):
  114. self.wrapped = wrapped
  115. update_wrapper(self, wrapped)
  116. def __get__(self, instance, obj_type=None):
  117. if instance is None:
  118. return _lazy_property_and_property(self.wrapped)
  119. with torch.enable_grad():
  120. value = self.wrapped(instance)
  121. setattr(instance, self.wrapped.__name__, value)
  122. return value
class _lazy_property_and_property(lazy_property, property):
    """We want lazy properties to look like multiple things.

    * property when Sphinx autodoc looks
    * lazy_property when Distribution validate_args looks
    """

    def __init__(self, wrapped):
        # Deliberately bypass lazy_property.__init__ and initialize only the
        # property side, so introspection tools treat this as a plain property.
        property.__init__(self, wrapped)
  130. def tril_matrix_to_vec(mat: torch.Tensor, diag: int = 0) -> torch.Tensor:
  131. r"""
  132. Convert a `D x D` matrix or a batch of matrices into a (batched) vector
  133. which comprises of lower triangular elements from the matrix in row order.
  134. """
  135. n = mat.shape[-1]
  136. if not torch._C._get_tracing_state() and (diag < -n or diag >= n):
  137. raise ValueError(f"diag ({diag}) provided is outside [{-n}, {n-1}].")
  138. arange = torch.arange(n, device=mat.device)
  139. tril_mask = arange < arange.view(-1, 1) + (diag + 1)
  140. vec = mat[..., tril_mask]
  141. return vec
  142. def vec_to_tril_matrix(vec: torch.Tensor, diag: int = 0) -> torch.Tensor:
  143. r"""
  144. Convert a vector or a batch of vectors into a batched `D x D`
  145. lower triangular matrix containing elements from the vector in row order.
  146. """
  147. # +ve root of D**2 + (1+2*diag)*D - |diag| * (diag+1) - 2*vec.shape[-1] = 0
  148. n = (
  149. -(1 + 2 * diag)
  150. + ((1 + 2 * diag) ** 2 + 8 * vec.shape[-1] + 4 * abs(diag) * (diag + 1)) ** 0.5
  151. ) / 2
  152. eps = torch.finfo(vec.dtype).eps
  153. if not torch._C._get_tracing_state() and (round(n) - n > eps):
  154. raise ValueError(
  155. f"The size of last dimension is {vec.shape[-1]} which cannot be expressed as "
  156. + "the lower triangular part of a square D x D matrix."
  157. )
  158. n = round(n.item()) if isinstance(n, torch.Tensor) else round(n)
  159. mat = vec.new_zeros(vec.shape[:-1] + torch.Size((n, n)))
  160. arange = torch.arange(n, device=vec.device)
  161. tril_mask = arange < arange.view(-1, 1) + (diag + 1)
  162. mat[..., tril_mask] = vec
  163. return mat