  1. # mypy: allow-untyped-defs
  2. import math
  3. from numbers import Number
  4. import torch
  5. from torch.distributions import constraints
  6. from torch.distributions.exp_family import ExponentialFamily
  7. from torch.distributions.utils import (
  8. broadcast_all,
  9. clamp_probs,
  10. lazy_property,
  11. logits_to_probs,
  12. probs_to_logits,
  13. )
  14. from torch.nn.functional import binary_cross_entropy_with_logits
  15. __all__ = ["ContinuousBernoulli"]
  16. class ContinuousBernoulli(ExponentialFamily):
  17. r"""
  18. Creates a continuous Bernoulli distribution parameterized by :attr:`probs`
  19. or :attr:`logits` (but not both).
  20. The distribution is supported in [0, 1] and parameterized by 'probs' (in
  21. (0,1)) or 'logits' (real-valued). Note that, unlike the Bernoulli, 'probs'
  22. does not correspond to a probability and 'logits' does not correspond to
  23. log-odds, but the same names are used due to the similarity with the
  24. Bernoulli. See [1] for more details.
  25. Example::
  26. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  27. >>> m = ContinuousBernoulli(torch.tensor([0.3]))
  28. >>> m.sample()
  29. tensor([ 0.2538])
  30. Args:
  31. probs (Number, Tensor): (0,1) valued parameters
  32. logits (Number, Tensor): real valued parameters whose sigmoid matches 'probs'
  33. [1] The continuous Bernoulli: fixing a pervasive error in variational
  34. autoencoders, Loaiza-Ganem G and Cunningham JP, NeurIPS 2019.
  35. https://arxiv.org/abs/1907.06845
  36. """
  37. arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
  38. support = constraints.unit_interval
  39. _mean_carrier_measure = 0
  40. has_rsample = True
  41. def __init__(
  42. self, probs=None, logits=None, lims=(0.499, 0.501), validate_args=None
  43. ):
  44. if (probs is None) == (logits is None):
  45. raise ValueError(
  46. "Either `probs` or `logits` must be specified, but not both."
  47. )
  48. if probs is not None:
  49. is_scalar = isinstance(probs, Number)
  50. (self.probs,) = broadcast_all(probs)
  51. # validate 'probs' here if necessary as it is later clamped for numerical stability
  52. # close to 0 and 1, later on; otherwise the clamped 'probs' would always pass
  53. if validate_args is not None:
  54. if not self.arg_constraints["probs"].check(self.probs).all():
  55. raise ValueError("The parameter probs has invalid values")
  56. self.probs = clamp_probs(self.probs)
  57. else:
  58. is_scalar = isinstance(logits, Number)
  59. (self.logits,) = broadcast_all(logits)
  60. self._param = self.probs if probs is not None else self.logits
  61. if is_scalar:
  62. batch_shape = torch.Size()
  63. else:
  64. batch_shape = self._param.size()
  65. self._lims = lims
  66. super().__init__(batch_shape, validate_args=validate_args)
  67. def expand(self, batch_shape, _instance=None):
  68. new = self._get_checked_instance(ContinuousBernoulli, _instance)
  69. new._lims = self._lims
  70. batch_shape = torch.Size(batch_shape)
  71. if "probs" in self.__dict__:
  72. new.probs = self.probs.expand(batch_shape)
  73. new._param = new.probs
  74. if "logits" in self.__dict__:
  75. new.logits = self.logits.expand(batch_shape)
  76. new._param = new.logits
  77. super(ContinuousBernoulli, new).__init__(batch_shape, validate_args=False)
  78. new._validate_args = self._validate_args
  79. return new
  80. def _new(self, *args, **kwargs):
  81. return self._param.new(*args, **kwargs)
  82. def _outside_unstable_region(self):
  83. return torch.max(
  84. torch.le(self.probs, self._lims[0]), torch.gt(self.probs, self._lims[1])
  85. )
  86. def _cut_probs(self):
  87. return torch.where(
  88. self._outside_unstable_region(),
  89. self.probs,
  90. self._lims[0] * torch.ones_like(self.probs),
  91. )
  92. def _cont_bern_log_norm(self):
  93. """computes the log normalizing constant as a function of the 'probs' parameter"""
  94. cut_probs = self._cut_probs()
  95. cut_probs_below_half = torch.where(
  96. torch.le(cut_probs, 0.5), cut_probs, torch.zeros_like(cut_probs)
  97. )
  98. cut_probs_above_half = torch.where(
  99. torch.ge(cut_probs, 0.5), cut_probs, torch.ones_like(cut_probs)
  100. )
  101. log_norm = torch.log(
  102. torch.abs(torch.log1p(-cut_probs) - torch.log(cut_probs))
  103. ) - torch.where(
  104. torch.le(cut_probs, 0.5),
  105. torch.log1p(-2.0 * cut_probs_below_half),
  106. torch.log(2.0 * cut_probs_above_half - 1.0),
  107. )
  108. x = torch.pow(self.probs - 0.5, 2)
  109. taylor = math.log(2.0) + (4.0 / 3.0 + 104.0 / 45.0 * x) * x
  110. return torch.where(self._outside_unstable_region(), log_norm, taylor)
  111. @property
  112. def mean(self):
  113. cut_probs = self._cut_probs()
  114. mus = cut_probs / (2.0 * cut_probs - 1.0) + 1.0 / (
  115. torch.log1p(-cut_probs) - torch.log(cut_probs)
  116. )
  117. x = self.probs - 0.5
  118. taylor = 0.5 + (1.0 / 3.0 + 16.0 / 45.0 * torch.pow(x, 2)) * x
  119. return torch.where(self._outside_unstable_region(), mus, taylor)
  120. @property
  121. def stddev(self):
  122. return torch.sqrt(self.variance)
  123. @property
  124. def variance(self):
  125. cut_probs = self._cut_probs()
  126. vars = cut_probs * (cut_probs - 1.0) / torch.pow(
  127. 1.0 - 2.0 * cut_probs, 2
  128. ) + 1.0 / torch.pow(torch.log1p(-cut_probs) - torch.log(cut_probs), 2)
  129. x = torch.pow(self.probs - 0.5, 2)
  130. taylor = 1.0 / 12.0 - (1.0 / 15.0 - 128.0 / 945.0 * x) * x
  131. return torch.where(self._outside_unstable_region(), vars, taylor)
  132. @lazy_property
  133. def logits(self):
  134. return probs_to_logits(self.probs, is_binary=True)
  135. @lazy_property
  136. def probs(self):
  137. return clamp_probs(logits_to_probs(self.logits, is_binary=True))
  138. @property
  139. def param_shape(self):
  140. return self._param.size()
  141. def sample(self, sample_shape=torch.Size()):
  142. shape = self._extended_shape(sample_shape)
  143. u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device)
  144. with torch.no_grad():
  145. return self.icdf(u)
  146. def rsample(self, sample_shape=torch.Size()):
  147. shape = self._extended_shape(sample_shape)
  148. u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device)
  149. return self.icdf(u)
  150. def log_prob(self, value):
  151. if self._validate_args:
  152. self._validate_sample(value)
  153. logits, value = broadcast_all(self.logits, value)
  154. return (
  155. -binary_cross_entropy_with_logits(logits, value, reduction="none")
  156. + self._cont_bern_log_norm()
  157. )
  158. def cdf(self, value):
  159. if self._validate_args:
  160. self._validate_sample(value)
  161. cut_probs = self._cut_probs()
  162. cdfs = (
  163. torch.pow(cut_probs, value) * torch.pow(1.0 - cut_probs, 1.0 - value)
  164. + cut_probs
  165. - 1.0
  166. ) / (2.0 * cut_probs - 1.0)
  167. unbounded_cdfs = torch.where(self._outside_unstable_region(), cdfs, value)
  168. return torch.where(
  169. torch.le(value, 0.0),
  170. torch.zeros_like(value),
  171. torch.where(torch.ge(value, 1.0), torch.ones_like(value), unbounded_cdfs),
  172. )
  173. def icdf(self, value):
  174. cut_probs = self._cut_probs()
  175. return torch.where(
  176. self._outside_unstable_region(),
  177. (
  178. torch.log1p(-cut_probs + value * (2.0 * cut_probs - 1.0))
  179. - torch.log1p(-cut_probs)
  180. )
  181. / (torch.log(cut_probs) - torch.log1p(-cut_probs)),
  182. value,
  183. )
  184. def entropy(self):
  185. log_probs0 = torch.log1p(-self.probs)
  186. log_probs1 = torch.log(self.probs)
  187. return (
  188. self.mean * (log_probs0 - log_probs1)
  189. - self._cont_bern_log_norm()
  190. - log_probs0
  191. )
  192. @property
  193. def _natural_params(self):
  194. return (self.logits,)
  195. def _log_normalizer(self, x):
  196. """computes the log normalizing constant as a function of the natural parameter"""
  197. out_unst_reg = torch.max(
  198. torch.le(x, self._lims[0] - 0.5), torch.gt(x, self._lims[1] - 0.5)
  199. )
  200. cut_nat_params = torch.where(
  201. out_unst_reg, x, (self._lims[0] - 0.5) * torch.ones_like(x)
  202. )
  203. log_norm = torch.log(torch.abs(torch.exp(cut_nat_params) - 1.0)) - torch.log(
  204. torch.abs(cut_nat_params)
  205. )
  206. taylor = 0.5 * x + torch.pow(x, 2) / 24.0 - torch.pow(x, 4) / 2880.0
  207. return torch.where(out_unst_reg, log_norm, taylor)