exponential.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. # mypy: allow-untyped-defs
  2. from numbers import Number
  3. import torch
  4. from torch.distributions import constraints
  5. from torch.distributions.exp_family import ExponentialFamily
  6. from torch.distributions.utils import broadcast_all
  7. __all__ = ["Exponential"]
  8. class Exponential(ExponentialFamily):
  9. r"""
  10. Creates a Exponential distribution parameterized by :attr:`rate`.
  11. Example::
  12. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  13. >>> m = Exponential(torch.tensor([1.0]))
  14. >>> m.sample() # Exponential distributed with rate=1
  15. tensor([ 0.1046])
  16. Args:
  17. rate (float or Tensor): rate = 1 / scale of the distribution
  18. """
  19. arg_constraints = {"rate": constraints.positive}
  20. support = constraints.nonnegative
  21. has_rsample = True
  22. _mean_carrier_measure = 0
  23. @property
  24. def mean(self):
  25. return self.rate.reciprocal()
  26. @property
  27. def mode(self):
  28. return torch.zeros_like(self.rate)
  29. @property
  30. def stddev(self):
  31. return self.rate.reciprocal()
  32. @property
  33. def variance(self):
  34. return self.rate.pow(-2)
  35. def __init__(self, rate, validate_args=None):
  36. (self.rate,) = broadcast_all(rate)
  37. batch_shape = torch.Size() if isinstance(rate, Number) else self.rate.size()
  38. super().__init__(batch_shape, validate_args=validate_args)
  39. def expand(self, batch_shape, _instance=None):
  40. new = self._get_checked_instance(Exponential, _instance)
  41. batch_shape = torch.Size(batch_shape)
  42. new.rate = self.rate.expand(batch_shape)
  43. super(Exponential, new).__init__(batch_shape, validate_args=False)
  44. new._validate_args = self._validate_args
  45. return new
  46. def rsample(self, sample_shape=torch.Size()):
  47. shape = self._extended_shape(sample_shape)
  48. return self.rate.new(shape).exponential_() / self.rate
  49. def log_prob(self, value):
  50. if self._validate_args:
  51. self._validate_sample(value)
  52. return self.rate.log() - self.rate * value
  53. def cdf(self, value):
  54. if self._validate_args:
  55. self._validate_sample(value)
  56. return 1 - torch.exp(-self.rate * value)
  57. def icdf(self, value):
  58. return -torch.log1p(-value) / self.rate
  59. def entropy(self):
  60. return 1.0 - torch.log(self.rate)
  61. @property
  62. def _natural_params(self):
  63. return (-self.rate,)
  64. def _log_normalizer(self, x):
  65. return -torch.log(-x)