# multinomial.py
  1. # mypy: allow-untyped-defs
  2. import torch
  3. from torch import inf
  4. from torch.distributions import Categorical, constraints
  5. from torch.distributions.binomial import Binomial
  6. from torch.distributions.distribution import Distribution
  7. from torch.distributions.utils import broadcast_all
  8. __all__ = ["Multinomial"]
  9. class Multinomial(Distribution):
  10. r"""
  11. Creates a Multinomial distribution parameterized by :attr:`total_count` and
  12. either :attr:`probs` or :attr:`logits` (but not both). The innermost dimension of
  13. :attr:`probs` indexes over categories. All other dimensions index over batches.
  14. Note that :attr:`total_count` need not be specified if only :meth:`log_prob` is
  15. called (see example below)
  16. .. note:: The `probs` argument must be non-negative, finite and have a non-zero sum,
  17. and it will be normalized to sum to 1 along the last dimension. :attr:`probs`
  18. will return this normalized value.
  19. The `logits` argument will be interpreted as unnormalized log probabilities
  20. and can therefore be any real number. It will likewise be normalized so that
  21. the resulting probabilities sum to 1 along the last dimension. :attr:`logits`
  22. will return this normalized value.
  23. - :meth:`sample` requires a single shared `total_count` for all
  24. parameters and samples.
  25. - :meth:`log_prob` allows different `total_count` for each parameter and
  26. sample.
  27. Example::
  28. >>> # xdoctest: +SKIP("FIXME: found invalid values")
  29. >>> m = Multinomial(100, torch.tensor([ 1., 1., 1., 1.]))
  30. >>> x = m.sample() # equal probability of 0, 1, 2, 3
  31. tensor([ 21., 24., 30., 25.])
  32. >>> Multinomial(probs=torch.tensor([1., 1., 1., 1.])).log_prob(x)
  33. tensor([-4.1338])
  34. Args:
  35. total_count (int): number of trials
  36. probs (Tensor): event probabilities
  37. logits (Tensor): event log probabilities (unnormalized)
  38. """
  39. arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
  40. total_count: int
  41. @property
  42. def mean(self):
  43. return self.probs * self.total_count
  44. @property
  45. def variance(self):
  46. return self.total_count * self.probs * (1 - self.probs)
  47. def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
  48. if not isinstance(total_count, int):
  49. raise NotImplementedError("inhomogeneous total_count is not supported")
  50. self.total_count = total_count
  51. self._categorical = Categorical(probs=probs, logits=logits)
  52. self._binomial = Binomial(total_count=total_count, probs=self.probs)
  53. batch_shape = self._categorical.batch_shape
  54. event_shape = self._categorical.param_shape[-1:]
  55. super().__init__(batch_shape, event_shape, validate_args=validate_args)
  56. def expand(self, batch_shape, _instance=None):
  57. new = self._get_checked_instance(Multinomial, _instance)
  58. batch_shape = torch.Size(batch_shape)
  59. new.total_count = self.total_count
  60. new._categorical = self._categorical.expand(batch_shape)
  61. super(Multinomial, new).__init__(
  62. batch_shape, self.event_shape, validate_args=False
  63. )
  64. new._validate_args = self._validate_args
  65. return new
  66. def _new(self, *args, **kwargs):
  67. return self._categorical._new(*args, **kwargs)
  68. @constraints.dependent_property(is_discrete=True, event_dim=1)
  69. def support(self):
  70. return constraints.multinomial(self.total_count)
  71. @property
  72. def logits(self):
  73. return self._categorical.logits
  74. @property
  75. def probs(self):
  76. return self._categorical.probs
  77. @property
  78. def param_shape(self):
  79. return self._categorical.param_shape
  80. def sample(self, sample_shape=torch.Size()):
  81. sample_shape = torch.Size(sample_shape)
  82. samples = self._categorical.sample(
  83. torch.Size((self.total_count,)) + sample_shape
  84. )
  85. # samples.shape is (total_count, sample_shape, batch_shape), need to change it to
  86. # (sample_shape, batch_shape, total_count)
  87. shifted_idx = list(range(samples.dim()))
  88. shifted_idx.append(shifted_idx.pop(0))
  89. samples = samples.permute(*shifted_idx)
  90. counts = samples.new(self._extended_shape(sample_shape)).zero_()
  91. counts.scatter_add_(-1, samples, torch.ones_like(samples))
  92. return counts.type_as(self.probs)
  93. def entropy(self):
  94. n = torch.tensor(self.total_count)
  95. cat_entropy = self._categorical.entropy()
  96. term1 = n * cat_entropy - torch.lgamma(n + 1)
  97. support = self._binomial.enumerate_support(expand=False)[1:]
  98. binomial_probs = torch.exp(self._binomial.log_prob(support))
  99. weights = torch.lgamma(support + 1)
  100. term2 = (binomial_probs * weights).sum([0, -1])
  101. return term1 + term2
  102. def log_prob(self, value):
  103. if self._validate_args:
  104. self._validate_sample(value)
  105. logits, value = broadcast_all(self.logits, value)
  106. logits = logits.clone(memory_format=torch.contiguous_format)
  107. log_factorial_n = torch.lgamma(value.sum(-1) + 1)
  108. log_factorial_xs = torch.lgamma(value + 1).sum(-1)
  109. logits[(value == 0) & (logits == -inf)] = 0
  110. log_powers = (logits * value).sum(-1)
  111. return log_factorial_n - log_factorial_xs + log_powers