transforms.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246
  1. # mypy: allow-untyped-defs
  2. import functools
  3. import math
  4. import numbers
  5. import operator
  6. import weakref
  7. from typing import List
  8. import torch
  9. import torch.nn.functional as F
  10. from torch.distributions import constraints
  11. from torch.distributions.utils import (
  12. _sum_rightmost,
  13. broadcast_all,
  14. lazy_property,
  15. tril_matrix_to_vec,
  16. vec_to_tril_matrix,
  17. )
  18. from torch.nn.functional import pad, softplus
  19. __all__ = [
  20. "AbsTransform",
  21. "AffineTransform",
  22. "CatTransform",
  23. "ComposeTransform",
  24. "CorrCholeskyTransform",
  25. "CumulativeDistributionTransform",
  26. "ExpTransform",
  27. "IndependentTransform",
  28. "LowerCholeskyTransform",
  29. "PositiveDefiniteTransform",
  30. "PowerTransform",
  31. "ReshapeTransform",
  32. "SigmoidTransform",
  33. "SoftplusTransform",
  34. "TanhTransform",
  35. "SoftmaxTransform",
  36. "StackTransform",
  37. "StickBreakingTransform",
  38. "Transform",
  39. "identity_transform",
  40. ]
  41. class Transform:
  42. """
  43. Abstract class for invertable transformations with computable log
  44. det jacobians. They are primarily used in
  45. :class:`torch.distributions.TransformedDistribution`.
  46. Caching is useful for transforms whose inverses are either expensive or
  47. numerically unstable. Note that care must be taken with memoized values
  48. since the autograd graph may be reversed. For example while the following
  49. works with or without caching::
  50. y = t(x)
  51. t.log_abs_det_jacobian(x, y).backward() # x will receive gradients.
  52. However the following will error when caching due to dependency reversal::
  53. y = t(x)
  54. z = t.inv(y)
  55. grad(z.sum(), [y]) # error because z is x
  56. Derived classes should implement one or both of :meth:`_call` or
  57. :meth:`_inverse`. Derived classes that set `bijective=True` should also
  58. implement :meth:`log_abs_det_jacobian`.
  59. Args:
  60. cache_size (int): Size of cache. If zero, no caching is done. If one,
  61. the latest single value is cached. Only 0 and 1 are supported.
  62. Attributes:
  63. domain (:class:`~torch.distributions.constraints.Constraint`):
  64. The constraint representing valid inputs to this transform.
  65. codomain (:class:`~torch.distributions.constraints.Constraint`):
  66. The constraint representing valid outputs to this transform
  67. which are inputs to the inverse transform.
  68. bijective (bool): Whether this transform is bijective. A transform
  69. ``t`` is bijective iff ``t.inv(t(x)) == x`` and
  70. ``t(t.inv(y)) == y`` for every ``x`` in the domain and ``y`` in
  71. the codomain. Transforms that are not bijective should at least
  72. maintain the weaker pseudoinverse properties
  73. ``t(t.inv(t(x)) == t(x)`` and ``t.inv(t(t.inv(y))) == t.inv(y)``.
  74. sign (int or Tensor): For bijective univariate transforms, this
  75. should be +1 or -1 depending on whether transform is monotone
  76. increasing or decreasing.
  77. """
  78. bijective = False
  79. domain: constraints.Constraint
  80. codomain: constraints.Constraint
  81. def __init__(self, cache_size=0):
  82. self._cache_size = cache_size
  83. self._inv = None
  84. if cache_size == 0:
  85. pass # default behavior
  86. elif cache_size == 1:
  87. self._cached_x_y = None, None
  88. else:
  89. raise ValueError("cache_size must be 0 or 1")
  90. super().__init__()
  91. def __getstate__(self):
  92. state = self.__dict__.copy()
  93. state["_inv"] = None
  94. return state
  95. @property
  96. def event_dim(self):
  97. if self.domain.event_dim == self.codomain.event_dim:
  98. return self.domain.event_dim
  99. raise ValueError("Please use either .domain.event_dim or .codomain.event_dim")
  100. @property
  101. def inv(self):
  102. """
  103. Returns the inverse :class:`Transform` of this transform.
  104. This should satisfy ``t.inv.inv is t``.
  105. """
  106. inv = None
  107. if self._inv is not None:
  108. inv = self._inv()
  109. if inv is None:
  110. inv = _InverseTransform(self)
  111. self._inv = weakref.ref(inv)
  112. return inv
  113. @property
  114. def sign(self):
  115. """
  116. Returns the sign of the determinant of the Jacobian, if applicable.
  117. In general this only makes sense for bijective transforms.
  118. """
  119. raise NotImplementedError
  120. def with_cache(self, cache_size=1):
  121. if self._cache_size == cache_size:
  122. return self
  123. if type(self).__init__ is Transform.__init__:
  124. return type(self)(cache_size=cache_size)
  125. raise NotImplementedError(f"{type(self)}.with_cache is not implemented")
  126. def __eq__(self, other):
  127. return self is other
  128. def __ne__(self, other):
  129. # Necessary for Python2
  130. return not self.__eq__(other)
  131. def __call__(self, x):
  132. """
  133. Computes the transform `x => y`.
  134. """
  135. if self._cache_size == 0:
  136. return self._call(x)
  137. x_old, y_old = self._cached_x_y
  138. if x is x_old:
  139. return y_old
  140. y = self._call(x)
  141. self._cached_x_y = x, y
  142. return y
  143. def _inv_call(self, y):
  144. """
  145. Inverts the transform `y => x`.
  146. """
  147. if self._cache_size == 0:
  148. return self._inverse(y)
  149. x_old, y_old = self._cached_x_y
  150. if y is y_old:
  151. return x_old
  152. x = self._inverse(y)
  153. self._cached_x_y = x, y
  154. return x
  155. def _call(self, x):
  156. """
  157. Abstract method to compute forward transformation.
  158. """
  159. raise NotImplementedError
  160. def _inverse(self, y):
  161. """
  162. Abstract method to compute inverse transformation.
  163. """
  164. raise NotImplementedError
  165. def log_abs_det_jacobian(self, x, y):
  166. """
  167. Computes the log det jacobian `log |dy/dx|` given input and output.
  168. """
  169. raise NotImplementedError
  170. def __repr__(self):
  171. return self.__class__.__name__ + "()"
  172. def forward_shape(self, shape):
  173. """
  174. Infers the shape of the forward computation, given the input shape.
  175. Defaults to preserving shape.
  176. """
  177. return shape
  178. def inverse_shape(self, shape):
  179. """
  180. Infers the shapes of the inverse computation, given the output shape.
  181. Defaults to preserving shape.
  182. """
  183. return shape
  184. class _InverseTransform(Transform):
  185. """
  186. Inverts a single :class:`Transform`.
  187. This class is private; please instead use the ``Transform.inv`` property.
  188. """
  189. def __init__(self, transform: Transform):
  190. super().__init__(cache_size=transform._cache_size)
  191. self._inv: Transform = transform
  192. @constraints.dependent_property(is_discrete=False)
  193. def domain(self):
  194. assert self._inv is not None
  195. return self._inv.codomain
  196. @constraints.dependent_property(is_discrete=False)
  197. def codomain(self):
  198. assert self._inv is not None
  199. return self._inv.domain
  200. @property
  201. def bijective(self):
  202. assert self._inv is not None
  203. return self._inv.bijective
  204. @property
  205. def sign(self):
  206. assert self._inv is not None
  207. return self._inv.sign
  208. @property
  209. def inv(self):
  210. return self._inv
  211. def with_cache(self, cache_size=1):
  212. assert self._inv is not None
  213. return self.inv.with_cache(cache_size).inv
  214. def __eq__(self, other):
  215. if not isinstance(other, _InverseTransform):
  216. return False
  217. assert self._inv is not None
  218. return self._inv == other._inv
  219. def __repr__(self):
  220. return f"{self.__class__.__name__}({repr(self._inv)})"
  221. def __call__(self, x):
  222. assert self._inv is not None
  223. return self._inv._inv_call(x)
  224. def log_abs_det_jacobian(self, x, y):
  225. assert self._inv is not None
  226. return -self._inv.log_abs_det_jacobian(y, x)
  227. def forward_shape(self, shape):
  228. return self._inv.inverse_shape(shape)
  229. def inverse_shape(self, shape):
  230. return self._inv.forward_shape(shape)
  231. class ComposeTransform(Transform):
  232. """
  233. Composes multiple transforms in a chain.
  234. The transforms being composed are responsible for caching.
  235. Args:
  236. parts (list of :class:`Transform`): A list of transforms to compose.
  237. cache_size (int): Size of cache. If zero, no caching is done. If one,
  238. the latest single value is cached. Only 0 and 1 are supported.
  239. """
  240. def __init__(self, parts: List[Transform], cache_size=0):
  241. if cache_size:
  242. parts = [part.with_cache(cache_size) for part in parts]
  243. super().__init__(cache_size=cache_size)
  244. self.parts = parts
  245. def __eq__(self, other):
  246. if not isinstance(other, ComposeTransform):
  247. return False
  248. return self.parts == other.parts
  249. @constraints.dependent_property(is_discrete=False)
  250. def domain(self):
  251. if not self.parts:
  252. return constraints.real
  253. domain = self.parts[0].domain
  254. # Adjust event_dim to be maximum among all parts.
  255. event_dim = self.parts[-1].codomain.event_dim
  256. for part in reversed(self.parts):
  257. event_dim += part.domain.event_dim - part.codomain.event_dim
  258. event_dim = max(event_dim, part.domain.event_dim)
  259. assert event_dim >= domain.event_dim
  260. if event_dim > domain.event_dim:
  261. domain = constraints.independent(domain, event_dim - domain.event_dim)
  262. return domain
  263. @constraints.dependent_property(is_discrete=False)
  264. def codomain(self):
  265. if not self.parts:
  266. return constraints.real
  267. codomain = self.parts[-1].codomain
  268. # Adjust event_dim to be maximum among all parts.
  269. event_dim = self.parts[0].domain.event_dim
  270. for part in self.parts:
  271. event_dim += part.codomain.event_dim - part.domain.event_dim
  272. event_dim = max(event_dim, part.codomain.event_dim)
  273. assert event_dim >= codomain.event_dim
  274. if event_dim > codomain.event_dim:
  275. codomain = constraints.independent(codomain, event_dim - codomain.event_dim)
  276. return codomain
  277. @lazy_property
  278. def bijective(self):
  279. return all(p.bijective for p in self.parts)
  280. @lazy_property
  281. def sign(self):
  282. sign = 1
  283. for p in self.parts:
  284. sign = sign * p.sign
  285. return sign
  286. @property
  287. def inv(self):
  288. inv = None
  289. if self._inv is not None:
  290. inv = self._inv()
  291. if inv is None:
  292. inv = ComposeTransform([p.inv for p in reversed(self.parts)])
  293. self._inv = weakref.ref(inv)
  294. inv._inv = weakref.ref(self)
  295. return inv
  296. def with_cache(self, cache_size=1):
  297. if self._cache_size == cache_size:
  298. return self
  299. return ComposeTransform(self.parts, cache_size=cache_size)
  300. def __call__(self, x):
  301. for part in self.parts:
  302. x = part(x)
  303. return x
  304. def log_abs_det_jacobian(self, x, y):
  305. if not self.parts:
  306. return torch.zeros_like(x)
  307. # Compute intermediates. This will be free if parts[:-1] are all cached.
  308. xs = [x]
  309. for part in self.parts[:-1]:
  310. xs.append(part(xs[-1]))
  311. xs.append(y)
  312. terms = []
  313. event_dim = self.domain.event_dim
  314. for part, x, y in zip(self.parts, xs[:-1], xs[1:]):
  315. terms.append(
  316. _sum_rightmost(
  317. part.log_abs_det_jacobian(x, y), event_dim - part.domain.event_dim
  318. )
  319. )
  320. event_dim += part.codomain.event_dim - part.domain.event_dim
  321. return functools.reduce(operator.add, terms)
  322. def forward_shape(self, shape):
  323. for part in self.parts:
  324. shape = part.forward_shape(shape)
  325. return shape
  326. def inverse_shape(self, shape):
  327. for part in reversed(self.parts):
  328. shape = part.inverse_shape(shape)
  329. return shape
  330. def __repr__(self):
  331. fmt_string = self.__class__.__name__ + "(\n "
  332. fmt_string += ",\n ".join([p.__repr__() for p in self.parts])
  333. fmt_string += "\n)"
  334. return fmt_string
  335. identity_transform = ComposeTransform([])
  336. class IndependentTransform(Transform):
  337. """
  338. Wrapper around another transform to treat
  339. ``reinterpreted_batch_ndims``-many extra of the right most dimensions as
  340. dependent. This has no effect on the forward or backward transforms, but
  341. does sum out ``reinterpreted_batch_ndims``-many of the rightmost dimensions
  342. in :meth:`log_abs_det_jacobian`.
  343. Args:
  344. base_transform (:class:`Transform`): A base transform.
  345. reinterpreted_batch_ndims (int): The number of extra rightmost
  346. dimensions to treat as dependent.
  347. """
  348. def __init__(self, base_transform, reinterpreted_batch_ndims, cache_size=0):
  349. super().__init__(cache_size=cache_size)
  350. self.base_transform = base_transform.with_cache(cache_size)
  351. self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
  352. def with_cache(self, cache_size=1):
  353. if self._cache_size == cache_size:
  354. return self
  355. return IndependentTransform(
  356. self.base_transform, self.reinterpreted_batch_ndims, cache_size=cache_size
  357. )
  358. @constraints.dependent_property(is_discrete=False)
  359. def domain(self):
  360. return constraints.independent(
  361. self.base_transform.domain, self.reinterpreted_batch_ndims
  362. )
  363. @constraints.dependent_property(is_discrete=False)
  364. def codomain(self):
  365. return constraints.independent(
  366. self.base_transform.codomain, self.reinterpreted_batch_ndims
  367. )
  368. @property
  369. def bijective(self):
  370. return self.base_transform.bijective
  371. @property
  372. def sign(self):
  373. return self.base_transform.sign
  374. def _call(self, x):
  375. if x.dim() < self.domain.event_dim:
  376. raise ValueError("Too few dimensions on input")
  377. return self.base_transform(x)
  378. def _inverse(self, y):
  379. if y.dim() < self.codomain.event_dim:
  380. raise ValueError("Too few dimensions on input")
  381. return self.base_transform.inv(y)
  382. def log_abs_det_jacobian(self, x, y):
  383. result = self.base_transform.log_abs_det_jacobian(x, y)
  384. result = _sum_rightmost(result, self.reinterpreted_batch_ndims)
  385. return result
  386. def __repr__(self):
  387. return f"{self.__class__.__name__}({repr(self.base_transform)}, {self.reinterpreted_batch_ndims})"
  388. def forward_shape(self, shape):
  389. return self.base_transform.forward_shape(shape)
  390. def inverse_shape(self, shape):
  391. return self.base_transform.inverse_shape(shape)
  392. class ReshapeTransform(Transform):
  393. """
  394. Unit Jacobian transform to reshape the rightmost part of a tensor.
  395. Note that ``in_shape`` and ``out_shape`` must have the same number of
  396. elements, just as for :meth:`torch.Tensor.reshape`.
  397. Arguments:
  398. in_shape (torch.Size): The input event shape.
  399. out_shape (torch.Size): The output event shape.
  400. """
  401. bijective = True
  402. def __init__(self, in_shape, out_shape, cache_size=0):
  403. self.in_shape = torch.Size(in_shape)
  404. self.out_shape = torch.Size(out_shape)
  405. if self.in_shape.numel() != self.out_shape.numel():
  406. raise ValueError("in_shape, out_shape have different numbers of elements")
  407. super().__init__(cache_size=cache_size)
  408. @constraints.dependent_property
  409. def domain(self):
  410. return constraints.independent(constraints.real, len(self.in_shape))
  411. @constraints.dependent_property
  412. def codomain(self):
  413. return constraints.independent(constraints.real, len(self.out_shape))
  414. def with_cache(self, cache_size=1):
  415. if self._cache_size == cache_size:
  416. return self
  417. return ReshapeTransform(self.in_shape, self.out_shape, cache_size=cache_size)
  418. def _call(self, x):
  419. batch_shape = x.shape[: x.dim() - len(self.in_shape)]
  420. return x.reshape(batch_shape + self.out_shape)
  421. def _inverse(self, y):
  422. batch_shape = y.shape[: y.dim() - len(self.out_shape)]
  423. return y.reshape(batch_shape + self.in_shape)
  424. def log_abs_det_jacobian(self, x, y):
  425. batch_shape = x.shape[: x.dim() - len(self.in_shape)]
  426. return x.new_zeros(batch_shape)
  427. def forward_shape(self, shape):
  428. if len(shape) < len(self.in_shape):
  429. raise ValueError("Too few dimensions on input")
  430. cut = len(shape) - len(self.in_shape)
  431. if shape[cut:] != self.in_shape:
  432. raise ValueError(
  433. f"Shape mismatch: expected {shape[cut:]} but got {self.in_shape}"
  434. )
  435. return shape[:cut] + self.out_shape
  436. def inverse_shape(self, shape):
  437. if len(shape) < len(self.out_shape):
  438. raise ValueError("Too few dimensions on input")
  439. cut = len(shape) - len(self.out_shape)
  440. if shape[cut:] != self.out_shape:
  441. raise ValueError(
  442. f"Shape mismatch: expected {shape[cut:]} but got {self.out_shape}"
  443. )
  444. return shape[:cut] + self.in_shape
  445. class ExpTransform(Transform):
  446. r"""
  447. Transform via the mapping :math:`y = \exp(x)`.
  448. """
  449. domain = constraints.real
  450. codomain = constraints.positive
  451. bijective = True
  452. sign = +1
  453. def __eq__(self, other):
  454. return isinstance(other, ExpTransform)
  455. def _call(self, x):
  456. return x.exp()
  457. def _inverse(self, y):
  458. return y.log()
  459. def log_abs_det_jacobian(self, x, y):
  460. return x
  461. class PowerTransform(Transform):
  462. r"""
  463. Transform via the mapping :math:`y = x^{\text{exponent}}`.
  464. """
  465. domain = constraints.positive
  466. codomain = constraints.positive
  467. bijective = True
  468. def __init__(self, exponent, cache_size=0):
  469. super().__init__(cache_size=cache_size)
  470. (self.exponent,) = broadcast_all(exponent)
  471. def with_cache(self, cache_size=1):
  472. if self._cache_size == cache_size:
  473. return self
  474. return PowerTransform(self.exponent, cache_size=cache_size)
  475. @lazy_property
  476. def sign(self):
  477. return self.exponent.sign()
  478. def __eq__(self, other):
  479. if not isinstance(other, PowerTransform):
  480. return False
  481. return self.exponent.eq(other.exponent).all().item()
  482. def _call(self, x):
  483. return x.pow(self.exponent)
  484. def _inverse(self, y):
  485. return y.pow(1 / self.exponent)
  486. def log_abs_det_jacobian(self, x, y):
  487. return (self.exponent * y / x).abs().log()
  488. def forward_shape(self, shape):
  489. return torch.broadcast_shapes(shape, getattr(self.exponent, "shape", ()))
  490. def inverse_shape(self, shape):
  491. return torch.broadcast_shapes(shape, getattr(self.exponent, "shape", ()))
  492. def _clipped_sigmoid(x):
  493. finfo = torch.finfo(x.dtype)
  494. return torch.clamp(torch.sigmoid(x), min=finfo.tiny, max=1.0 - finfo.eps)
  495. class SigmoidTransform(Transform):
  496. r"""
  497. Transform via the mapping :math:`y = \frac{1}{1 + \exp(-x)}` and :math:`x = \text{logit}(y)`.
  498. """
  499. domain = constraints.real
  500. codomain = constraints.unit_interval
  501. bijective = True
  502. sign = +1
  503. def __eq__(self, other):
  504. return isinstance(other, SigmoidTransform)
  505. def _call(self, x):
  506. return _clipped_sigmoid(x)
  507. def _inverse(self, y):
  508. finfo = torch.finfo(y.dtype)
  509. y = y.clamp(min=finfo.tiny, max=1.0 - finfo.eps)
  510. return y.log() - (-y).log1p()
  511. def log_abs_det_jacobian(self, x, y):
  512. return -F.softplus(-x) - F.softplus(x)
  513. class SoftplusTransform(Transform):
  514. r"""
  515. Transform via the mapping :math:`\text{Softplus}(x) = \log(1 + \exp(x))`.
  516. The implementation reverts to the linear function when :math:`x > 20`.
  517. """
  518. domain = constraints.real
  519. codomain = constraints.positive
  520. bijective = True
  521. sign = +1
  522. def __eq__(self, other):
  523. return isinstance(other, SoftplusTransform)
  524. def _call(self, x):
  525. return softplus(x)
  526. def _inverse(self, y):
  527. return (-y).expm1().neg().log() + y
  528. def log_abs_det_jacobian(self, x, y):
  529. return -softplus(-x)
  530. class TanhTransform(Transform):
  531. r"""
  532. Transform via the mapping :math:`y = \tanh(x)`.
  533. It is equivalent to
  534. ```
  535. ComposeTransform([AffineTransform(0., 2.), SigmoidTransform(), AffineTransform(-1., 2.)])
  536. ```
  537. However this might not be numerically stable, thus it is recommended to use `TanhTransform`
  538. instead.
  539. Note that one should use `cache_size=1` when it comes to `NaN/Inf` values.
  540. """
  541. domain = constraints.real
  542. codomain = constraints.interval(-1.0, 1.0)
  543. bijective = True
  544. sign = +1
  545. def __eq__(self, other):
  546. return isinstance(other, TanhTransform)
  547. def _call(self, x):
  548. return x.tanh()
  549. def _inverse(self, y):
  550. # We do not clamp to the boundary here as it may degrade the performance of certain algorithms.
  551. # one should use `cache_size=1` instead
  552. return torch.atanh(y)
  553. def log_abs_det_jacobian(self, x, y):
  554. # We use a formula that is more numerically stable, see details in the following link
  555. # https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/tanh.py#L69-L80
  556. return 2.0 * (math.log(2.0) - x - softplus(-2.0 * x))
  557. class AbsTransform(Transform):
  558. r"""
  559. Transform via the mapping :math:`y = |x|`.
  560. """
  561. domain = constraints.real
  562. codomain = constraints.positive
  563. def __eq__(self, other):
  564. return isinstance(other, AbsTransform)
  565. def _call(self, x):
  566. return x.abs()
  567. def _inverse(self, y):
  568. return y
  569. class AffineTransform(Transform):
  570. r"""
  571. Transform via the pointwise affine mapping :math:`y = \text{loc} + \text{scale} \times x`.
  572. Args:
  573. loc (Tensor or float): Location parameter.
  574. scale (Tensor or float): Scale parameter.
  575. event_dim (int): Optional size of `event_shape`. This should be zero
  576. for univariate random variables, 1 for distributions over vectors,
  577. 2 for distributions over matrices, etc.
  578. """
  579. bijective = True
  580. def __init__(self, loc, scale, event_dim=0, cache_size=0):
  581. super().__init__(cache_size=cache_size)
  582. self.loc = loc
  583. self.scale = scale
  584. self._event_dim = event_dim
  585. @property
  586. def event_dim(self):
  587. return self._event_dim
  588. @constraints.dependent_property(is_discrete=False)
  589. def domain(self):
  590. if self.event_dim == 0:
  591. return constraints.real
  592. return constraints.independent(constraints.real, self.event_dim)
  593. @constraints.dependent_property(is_discrete=False)
  594. def codomain(self):
  595. if self.event_dim == 0:
  596. return constraints.real
  597. return constraints.independent(constraints.real, self.event_dim)
  598. def with_cache(self, cache_size=1):
  599. if self._cache_size == cache_size:
  600. return self
  601. return AffineTransform(
  602. self.loc, self.scale, self.event_dim, cache_size=cache_size
  603. )
  604. def __eq__(self, other):
  605. if not isinstance(other, AffineTransform):
  606. return False
  607. if isinstance(self.loc, numbers.Number) and isinstance(
  608. other.loc, numbers.Number
  609. ):
  610. if self.loc != other.loc:
  611. return False
  612. else:
  613. if not (self.loc == other.loc).all().item():
  614. return False
  615. if isinstance(self.scale, numbers.Number) and isinstance(
  616. other.scale, numbers.Number
  617. ):
  618. if self.scale != other.scale:
  619. return False
  620. else:
  621. if not (self.scale == other.scale).all().item():
  622. return False
  623. return True
  624. @property
  625. def sign(self):
  626. if isinstance(self.scale, numbers.Real):
  627. return 1 if float(self.scale) > 0 else -1 if float(self.scale) < 0 else 0
  628. return self.scale.sign()
  629. def _call(self, x):
  630. return self.loc + self.scale * x
  631. def _inverse(self, y):
  632. return (y - self.loc) / self.scale
  633. def log_abs_det_jacobian(self, x, y):
  634. shape = x.shape
  635. scale = self.scale
  636. if isinstance(scale, numbers.Real):
  637. result = torch.full_like(x, math.log(abs(scale)))
  638. else:
  639. result = torch.abs(scale).log()
  640. if self.event_dim:
  641. result_size = result.size()[: -self.event_dim] + (-1,)
  642. result = result.view(result_size).sum(-1)
  643. shape = shape[: -self.event_dim]
  644. return result.expand(shape)
  645. def forward_shape(self, shape):
  646. return torch.broadcast_shapes(
  647. shape, getattr(self.loc, "shape", ()), getattr(self.scale, "shape", ())
  648. )
  649. def inverse_shape(self, shape):
  650. return torch.broadcast_shapes(
  651. shape, getattr(self.loc, "shape", ()), getattr(self.scale, "shape", ())
  652. )
class CorrCholeskyTransform(Transform):
    r"""
    Transforms an unconstrained real vector :math:`x` with length :math:`D*(D-1)/2` into the
    Cholesky factor of a D-dimension correlation matrix. This Cholesky factor is a lower
    triangular matrix with positive diagonals and unit Euclidean norm for each row.
    The transform is processed as follows:

    1. First we convert x into a lower triangular matrix in row order.
    2. For each row :math:`X_i` of the lower triangular part, we apply a *signed* version of
       class :class:`StickBreakingTransform` to transform :math:`X_i` into a
       unit Euclidean length vector using the following steps:

       - Scales into the interval :math:`(-1, 1)` domain: :math:`r_i = \tanh(X_i)`.
       - Transforms into an unsigned domain: :math:`z_i = r_i^2`.
       - Applies :math:`s_i = StickBreakingTransform(z_i)`.
       - Transforms back into signed domain: :math:`y_i = sign(r_i) * \sqrt{s_i}`.
    """

    domain = constraints.real_vector
    codomain = constraints.corr_cholesky
    bijective = True

    def _call(self, x):
        # Input is (..., N) with N = D*(D-1)/2; output is (..., D, D)
        # (see forward_shape).
        x = torch.tanh(x)
        eps = torch.finfo(x.dtype).eps
        # Keep |r| strictly below 1 so that 1 - r**2 stays positive in the sqrt below.
        x = x.clamp(min=-1 + eps, max=1 - eps)
        # presumably fills the strictly-lower triangle (diag=-1); the diagonal
        # is set to 1 explicitly further down.
        r = vec_to_tril_matrix(x, diag=-1)
        # apply stick-breaking on the squared values
        # Note that y = sign(r) * sqrt(z * z1m_cumprod)
        # = (sign(r) * sqrt(z)) * sqrt(z1m_cumprod) = r * sqrt(z1m_cumprod)
        z = r**2
        z1m_cumprod_sqrt = (1 - z).sqrt().cumprod(-1)
        # Diagonal elements must be 1.
        r = r + torch.eye(r.shape[-1], dtype=r.dtype, device=r.device)
        # Shift the cumulative product one column right (padding with 1) so
        # column k is scaled by the product over columns < k only.
        y = r * pad(z1m_cumprod_sqrt[..., :-1], [1, 0], value=1)
        return y

    def _inverse(self, y):
        # inverse stick-breaking
        # See: https://mc-stan.org/docs/2_18/reference-manual/cholesky-factors-of-correlation-matrices-1.html
        y_cumsum = 1 - torch.cumsum(y * y, dim=-1)
        y_cumsum_shifted = pad(y_cumsum[..., :-1], [1, 0], value=1)
        y_vec = tril_matrix_to_vec(y, diag=-1)
        y_cumsum_vec = tril_matrix_to_vec(y_cumsum_shifted, diag=-1)
        t = y_vec / (y_cumsum_vec).sqrt()
        # inverse of tanh: atanh(t) = (log1p(t) - log1p(-t)) / 2
        x = (t.log1p() - t.neg().log1p()) / 2
        return x

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # `intermediates` is accepted but not used here.
        # Because domain and codomain are two spaces with different dimensions, determinant of
        # Jacobian is not well-defined. We return `log_abs_det_jacobian` of `x` and the
        # flattened lower triangular part of `y`.
        # See: https://mc-stan.org/docs/2_18/reference-manual/cholesky-factors-of-correlation-matrices-1.html
        y1m_cumsum = 1 - (y * y).cumsum(dim=-1)
        # by taking diagonal=-2, we don't need to shift z_cumprod to the right
        # also works for 2 x 2 matrix
        y1m_cumsum_tril = tril_matrix_to_vec(y1m_cumsum, diag=-2)
        stick_breaking_logdet = 0.5 * (y1m_cumsum_tril).log().sum(-1)
        # sum of log(1 - tanh(x)^2), written as -2 * (x + softplus(-2x) - log 2)
        # for numerical stability.
        tanh_logdet = -2 * (x + softplus(-2 * x) - math.log(2.0)).sum(dim=-1)
        return stick_breaking_logdet + tanh_logdet

    def forward_shape(self, shape):
        # Reshape from (..., N) to (..., D, D).
        if len(shape) < 1:
            raise ValueError("Too few dimensions on input")
        N = shape[-1]
        # Solve D * (D - 1) / 2 == N for D, then verify N is triangular.
        D = round((0.25 + 2 * N) ** 0.5 + 0.5)
        if D * (D - 1) // 2 != N:
            raise ValueError("Input is not a flattend lower-diagonal number")
        return shape[:-1] + (D, D)

    def inverse_shape(self, shape):
        # Reshape from (..., D, D) to (..., N).
        if len(shape) < 2:
            raise ValueError("Too few dimensions on input")
        if shape[-2] != shape[-1]:
            raise ValueError("Input is not square")
        D = shape[-1]
        N = D * (D - 1) // 2
        return shape[:-2] + (N,)
  726. class SoftmaxTransform(Transform):
  727. r"""
  728. Transform from unconstrained space to the simplex via :math:`y = \exp(x)` then
  729. normalizing.
  730. This is not bijective and cannot be used for HMC. However this acts mostly
  731. coordinate-wise (except for the final normalization), and thus is
  732. appropriate for coordinate-wise optimization algorithms.
  733. """
  734. domain = constraints.real_vector
  735. codomain = constraints.simplex
  736. def __eq__(self, other):
  737. return isinstance(other, SoftmaxTransform)
  738. def _call(self, x):
  739. logprobs = x
  740. probs = (logprobs - logprobs.max(-1, True)[0]).exp()
  741. return probs / probs.sum(-1, True)
  742. def _inverse(self, y):
  743. probs = y
  744. return probs.log()
  745. def forward_shape(self, shape):
  746. if len(shape) < 1:
  747. raise ValueError("Too few dimensions on input")
  748. return shape
  749. def inverse_shape(self, shape):
  750. if len(shape) < 1:
  751. raise ValueError("Too few dimensions on input")
  752. return shape
  753. class StickBreakingTransform(Transform):
  754. """
  755. Transform from unconstrained space to the simplex of one additional
  756. dimension via a stick-breaking process.
  757. This transform arises as an iterated sigmoid transform in a stick-breaking
  758. construction of the `Dirichlet` distribution: the first logit is
  759. transformed via sigmoid to the first probability and the probability of
  760. everything else, and then the process recurses.
  761. This is bijective and appropriate for use in HMC; however it mixes
  762. coordinates together and is less appropriate for optimization.
  763. """
  764. domain = constraints.real_vector
  765. codomain = constraints.simplex
  766. bijective = True
  767. def __eq__(self, other):
  768. return isinstance(other, StickBreakingTransform)
  769. def _call(self, x):
  770. offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1)
  771. z = _clipped_sigmoid(x - offset.log())
  772. z_cumprod = (1 - z).cumprod(-1)
  773. y = pad(z, [0, 1], value=1) * pad(z_cumprod, [1, 0], value=1)
  774. return y
  775. def _inverse(self, y):
  776. y_crop = y[..., :-1]
  777. offset = y.shape[-1] - y.new_ones(y_crop.shape[-1]).cumsum(-1)
  778. sf = 1 - y_crop.cumsum(-1)
  779. # we clamp to make sure that sf is positive which sometimes does not
  780. # happen when y[-1] ~ 0 or y[:-1].sum() ~ 1
  781. sf = torch.clamp(sf, min=torch.finfo(y.dtype).tiny)
  782. x = y_crop.log() - sf.log() + offset.log()
  783. return x
  784. def log_abs_det_jacobian(self, x, y):
  785. offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1)
  786. x = x - offset.log()
  787. # use the identity 1 - sigmoid(x) = exp(-x) * sigmoid(x)
  788. detJ = (-x + F.logsigmoid(x) + y[..., :-1].log()).sum(-1)
  789. return detJ
  790. def forward_shape(self, shape):
  791. if len(shape) < 1:
  792. raise ValueError("Too few dimensions on input")
  793. return shape[:-1] + (shape[-1] + 1,)
  794. def inverse_shape(self, shape):
  795. if len(shape) < 1:
  796. raise ValueError("Too few dimensions on input")
  797. return shape[:-1] + (shape[-1] - 1,)
  798. class LowerCholeskyTransform(Transform):
  799. """
  800. Transform from unconstrained matrices to lower-triangular matrices with
  801. nonnegative diagonal entries.
  802. This is useful for parameterizing positive definite matrices in terms of
  803. their Cholesky factorization.
  804. """
  805. domain = constraints.independent(constraints.real, 2)
  806. codomain = constraints.lower_cholesky
  807. def __eq__(self, other):
  808. return isinstance(other, LowerCholeskyTransform)
  809. def _call(self, x):
  810. return x.tril(-1) + x.diagonal(dim1=-2, dim2=-1).exp().diag_embed()
  811. def _inverse(self, y):
  812. return y.tril(-1) + y.diagonal(dim1=-2, dim2=-1).log().diag_embed()
  813. class PositiveDefiniteTransform(Transform):
  814. """
  815. Transform from unconstrained matrices to positive-definite matrices.
  816. """
  817. domain = constraints.independent(constraints.real, 2)
  818. codomain = constraints.positive_definite # type: ignore[assignment]
  819. def __eq__(self, other):
  820. return isinstance(other, PositiveDefiniteTransform)
  821. def _call(self, x):
  822. x = LowerCholeskyTransform()(x)
  823. return x @ x.mT
  824. def _inverse(self, y):
  825. y = torch.linalg.cholesky(y)
  826. return LowerCholeskyTransform().inv(y)
  827. class CatTransform(Transform):
  828. """
  829. Transform functor that applies a sequence of transforms `tseq`
  830. component-wise to each submatrix at `dim`, of length `lengths[dim]`,
  831. in a way compatible with :func:`torch.cat`.
  832. Example::
  833. x0 = torch.cat([torch.range(1, 10), torch.range(1, 10)], dim=0)
  834. x = torch.cat([x0, x0], dim=0)
  835. t0 = CatTransform([ExpTransform(), identity_transform], dim=0, lengths=[10, 10])
  836. t = CatTransform([t0, t0], dim=0, lengths=[20, 20])
  837. y = t(x)
  838. """
  839. transforms: List[Transform]
  840. def __init__(self, tseq, dim=0, lengths=None, cache_size=0):
  841. assert all(isinstance(t, Transform) for t in tseq)
  842. if cache_size:
  843. tseq = [t.with_cache(cache_size) for t in tseq]
  844. super().__init__(cache_size=cache_size)
  845. self.transforms = list(tseq)
  846. if lengths is None:
  847. lengths = [1] * len(self.transforms)
  848. self.lengths = list(lengths)
  849. assert len(self.lengths) == len(self.transforms)
  850. self.dim = dim
  851. @lazy_property
  852. def event_dim(self):
  853. return max(t.event_dim for t in self.transforms)
  854. @lazy_property
  855. def length(self):
  856. return sum(self.lengths)
  857. def with_cache(self, cache_size=1):
  858. if self._cache_size == cache_size:
  859. return self
  860. return CatTransform(self.transforms, self.dim, self.lengths, cache_size)
  861. def _call(self, x):
  862. assert -x.dim() <= self.dim < x.dim()
  863. assert x.size(self.dim) == self.length
  864. yslices = []
  865. start = 0
  866. for trans, length in zip(self.transforms, self.lengths):
  867. xslice = x.narrow(self.dim, start, length)
  868. yslices.append(trans(xslice))
  869. start = start + length # avoid += for jit compat
  870. return torch.cat(yslices, dim=self.dim)
  871. def _inverse(self, y):
  872. assert -y.dim() <= self.dim < y.dim()
  873. assert y.size(self.dim) == self.length
  874. xslices = []
  875. start = 0
  876. for trans, length in zip(self.transforms, self.lengths):
  877. yslice = y.narrow(self.dim, start, length)
  878. xslices.append(trans.inv(yslice))
  879. start = start + length # avoid += for jit compat
  880. return torch.cat(xslices, dim=self.dim)
  881. def log_abs_det_jacobian(self, x, y):
  882. assert -x.dim() <= self.dim < x.dim()
  883. assert x.size(self.dim) == self.length
  884. assert -y.dim() <= self.dim < y.dim()
  885. assert y.size(self.dim) == self.length
  886. logdetjacs = []
  887. start = 0
  888. for trans, length in zip(self.transforms, self.lengths):
  889. xslice = x.narrow(self.dim, start, length)
  890. yslice = y.narrow(self.dim, start, length)
  891. logdetjac = trans.log_abs_det_jacobian(xslice, yslice)
  892. if trans.event_dim < self.event_dim:
  893. logdetjac = _sum_rightmost(logdetjac, self.event_dim - trans.event_dim)
  894. logdetjacs.append(logdetjac)
  895. start = start + length # avoid += for jit compat
  896. # Decide whether to concatenate or sum.
  897. dim = self.dim
  898. if dim >= 0:
  899. dim = dim - x.dim()
  900. dim = dim + self.event_dim
  901. if dim < 0:
  902. return torch.cat(logdetjacs, dim=dim)
  903. else:
  904. return sum(logdetjacs)
  905. @property
  906. def bijective(self):
  907. return all(t.bijective for t in self.transforms)
  908. @constraints.dependent_property
  909. def domain(self):
  910. return constraints.cat(
  911. [t.domain for t in self.transforms], self.dim, self.lengths
  912. )
  913. @constraints.dependent_property
  914. def codomain(self):
  915. return constraints.cat(
  916. [t.codomain for t in self.transforms], self.dim, self.lengths
  917. )
  918. class StackTransform(Transform):
  919. """
  920. Transform functor that applies a sequence of transforms `tseq`
  921. component-wise to each submatrix at `dim`
  922. in a way compatible with :func:`torch.stack`.
  923. Example::
  924. x = torch.stack([torch.range(1, 10), torch.range(1, 10)], dim=1)
  925. t = StackTransform([ExpTransform(), identity_transform], dim=1)
  926. y = t(x)
  927. """
  928. transforms: List[Transform]
  929. def __init__(self, tseq, dim=0, cache_size=0):
  930. assert all(isinstance(t, Transform) for t in tseq)
  931. if cache_size:
  932. tseq = [t.with_cache(cache_size) for t in tseq]
  933. super().__init__(cache_size=cache_size)
  934. self.transforms = list(tseq)
  935. self.dim = dim
  936. def with_cache(self, cache_size=1):
  937. if self._cache_size == cache_size:
  938. return self
  939. return StackTransform(self.transforms, self.dim, cache_size)
  940. def _slice(self, z):
  941. return [z.select(self.dim, i) for i in range(z.size(self.dim))]
  942. def _call(self, x):
  943. assert -x.dim() <= self.dim < x.dim()
  944. assert x.size(self.dim) == len(self.transforms)
  945. yslices = []
  946. for xslice, trans in zip(self._slice(x), self.transforms):
  947. yslices.append(trans(xslice))
  948. return torch.stack(yslices, dim=self.dim)
  949. def _inverse(self, y):
  950. assert -y.dim() <= self.dim < y.dim()
  951. assert y.size(self.dim) == len(self.transforms)
  952. xslices = []
  953. for yslice, trans in zip(self._slice(y), self.transforms):
  954. xslices.append(trans.inv(yslice))
  955. return torch.stack(xslices, dim=self.dim)
  956. def log_abs_det_jacobian(self, x, y):
  957. assert -x.dim() <= self.dim < x.dim()
  958. assert x.size(self.dim) == len(self.transforms)
  959. assert -y.dim() <= self.dim < y.dim()
  960. assert y.size(self.dim) == len(self.transforms)
  961. logdetjacs = []
  962. yslices = self._slice(y)
  963. xslices = self._slice(x)
  964. for xslice, yslice, trans in zip(xslices, yslices, self.transforms):
  965. logdetjacs.append(trans.log_abs_det_jacobian(xslice, yslice))
  966. return torch.stack(logdetjacs, dim=self.dim)
  967. @property
  968. def bijective(self):
  969. return all(t.bijective for t in self.transforms)
  970. @constraints.dependent_property
  971. def domain(self):
  972. return constraints.stack([t.domain for t in self.transforms], self.dim)
  973. @constraints.dependent_property
  974. def codomain(self):
  975. return constraints.stack([t.codomain for t in self.transforms], self.dim)
  976. class CumulativeDistributionTransform(Transform):
  977. """
  978. Transform via the cumulative distribution function of a probability distribution.
  979. Args:
  980. distribution (Distribution): Distribution whose cumulative distribution function to use for
  981. the transformation.
  982. Example::
  983. # Construct a Gaussian copula from a multivariate normal.
  984. base_dist = MultivariateNormal(
  985. loc=torch.zeros(2),
  986. scale_tril=LKJCholesky(2).sample(),
  987. )
  988. transform = CumulativeDistributionTransform(Normal(0, 1))
  989. copula = TransformedDistribution(base_dist, [transform])
  990. """
  991. bijective = True
  992. codomain = constraints.unit_interval
  993. sign = +1
  994. def __init__(self, distribution, cache_size=0):
  995. super().__init__(cache_size=cache_size)
  996. self.distribution = distribution
  997. @property
  998. def domain(self):
  999. return self.distribution.support
  1000. def _call(self, x):
  1001. return self.distribution.cdf(x)
  1002. def _inverse(self, y):
  1003. return self.distribution.icdf(y)
  1004. def log_abs_det_jacobian(self, x, y):
  1005. return self.distribution.log_prob(x)
  1006. def with_cache(self, cache_size=1):
  1007. if self._cache_size == cache_size:
  1008. return self
  1009. return CumulativeDistributionTransform(self.distribution, cache_size=cache_size)