_cxx_pytree.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006
"""
Contains utility functions for working with nested python data structures.

A *pytree* is a Python nested data structure. It is a tree in the sense that
nodes are Python collections (e.g., list, tuple, dict) and the leaves are
Python values. Furthermore, a pytree should not contain reference cycles.

pytrees are useful for working with nested collections of Tensors. For example,
one can use `tree_map` to map a function over all Tensors inside some nested
collection of Tensors and `tree_leaves` to get a flat list of all Tensors
inside some nested collection. pytrees are helpful for implementing nested
collection support for PyTorch APIs.
"""
  12. import functools
  13. import sys
  14. import types
  15. from typing import (
  16. Any,
  17. Callable,
  18. Iterable,
  19. List,
  20. Optional,
  21. overload,
  22. Tuple,
  23. Type,
  24. TypeVar,
  25. Union,
  26. )
  27. from typing_extensions import deprecated
  28. import torch
  29. if torch._running_with_deploy(): # type: ignore[no-untyped-call]
  30. raise ImportError("C++ pytree utilities do not work with torch::deploy.")
  31. import optree
  32. from optree import PyTreeSpec # direct import for type annotations
  33. from torch.utils._pytree import KeyEntry
# Names exported by ``from <this module> import *``.
# NOTE(review): several entries (e.g. "LeafSpec", "keystr", "key_get",
# "tree_flatten_with_path", "treespec_loads", "treespec_pprint") are not defined
# in the portion of the file reviewed here -- presumably they are defined further
# down; confirm.
__all__ = [
    "PyTree",
    "Context",
    "FlattenFunc",
    "UnflattenFunc",
    "DumpableContext",
    "ToDumpableContextFn",
    "FromDumpableContextFn",
    "TreeSpec",
    "LeafSpec",
    "keystr",
    "key_get",
    "register_pytree_node",
    "tree_flatten",
    "tree_flatten_with_path",
    "tree_unflatten",
    "tree_iter",
    "tree_leaves",
    "tree_leaves_with_path",
    "tree_structure",
    "tree_map",
    "tree_map_with_path",
    "tree_map_",
    "tree_map_only",
    "tree_map_only_",
    "tree_all",
    "tree_any",
    "tree_all_only",
    "tree_any_only",
    "treespec_dumps",
    "treespec_loads",
    "treespec_pprint",
]
# Generic type variables used by the typed helpers and overloads in this module.
T = TypeVar("T")
S = TypeVar("S")
U = TypeVar("U")
R = TypeVar("R")


Context = Any
PyTree = Any
# Treespecs handled by this module are optree ``PyTreeSpec`` objects.
TreeSpec = PyTreeSpec
# Flatten: node -> (children, context).
FlattenFunc = Callable[[PyTree], Tuple[List[Any], Context]]
# Unflatten, torch argument order: (children, context) -> node.
UnflattenFunc = Callable[[Iterable[Any], Context], PyTree]
# Unflatten, optree argument order: (context, children) -> node.
# ``_reverse_args`` converts an ``UnflattenFunc`` into this form.
OpTreeUnflattenFunc = Callable[[Context, Iterable[Any]], PyTree]
DumpableContext = Any  # Any json dumpable text
ToDumpableContextFn = Callable[[Context], DumpableContext]
FromDumpableContextFn = Callable[[DumpableContext], Context]
KeyPath = Tuple[KeyEntry, ...]
# Flatten with keys: node -> ([(key_entry, child), ...], context).
FlattenWithKeysFunc = Callable[[PyTree], Tuple[List[Tuple[KeyEntry, Any]], Any]]
  82. def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc:
  83. @functools.wraps(func)
  84. def wrapped(*args: Any, **kwargs: Any) -> Any:
  85. return func(*reversed(args), **kwargs)
  86. return wrapped
  87. def register_pytree_node(
  88. cls: Type[Any],
  89. flatten_fn: FlattenFunc,
  90. unflatten_fn: UnflattenFunc,
  91. *,
  92. serialized_type_name: Optional[str] = None,
  93. to_dumpable_context: Optional[ToDumpableContextFn] = None,
  94. from_dumpable_context: Optional[FromDumpableContextFn] = None,
  95. flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None,
  96. ) -> None:
  97. """Register a container-like type as pytree node.
  98. Args:
  99. cls (type): A Python type to treat as an internal pytree node.
  100. flatten_fn (callable): A function to be used during flattening, taking an instance of
  101. ``cls`` and returning a pair, with (1) an iterable for the children to be flattened
  102. recursively, and (2) some hashable auxiliary data to be stored in the treespec and to be
  103. passed to the ``unflatten_fn``.
  104. unflatten_fn (callable): A function taking two arguments: the auxiliary data that was
  105. returned by ``flatten_fn`` and stored in the treespec, and the unflattened children.
  106. The function should return an instance of ``cls``.
  107. serialized_type_name (str, optional): A keyword argument used to specify the fully
  108. qualified name used when serializing the tree spec.
  109. to_dumpable_context (callable, optional): An optional keyword argument to custom specify how
  110. to convert the context of the pytree to a custom json dumpable representation. This is
  111. used for json serialization, which is being used in :mod:`torch.export` right now.
  112. from_dumpable_context (callable, optional): An optional keyword argument to custom specify
  113. how to convert the custom json dumpable representation of the context back to the
  114. original context. This is used for json deserialization, which is being used in
  115. :mod:`torch.export` right now.
  116. Example::
  117. >>> # xdoctest: +SKIP
  118. >>> # Registry a Python type with lambda functions
  119. >>> register_pytree_node(
  120. ... set,
  121. ... lambda s: (sorted(s), None, None),
  122. ... lambda children, _: set(children),
  123. ... )
  124. """
  125. if flatten_with_keys_fn is not None:
  126. raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")
  127. _private_register_pytree_node(
  128. cls,
  129. flatten_fn,
  130. unflatten_fn,
  131. serialized_type_name=serialized_type_name,
  132. to_dumpable_context=to_dumpable_context,
  133. from_dumpable_context=from_dumpable_context,
  134. )
  135. from . import _pytree as python
  136. python._private_register_pytree_node(
  137. cls,
  138. flatten_fn,
  139. unflatten_fn,
  140. serialized_type_name=serialized_type_name,
  141. to_dumpable_context=to_dumpable_context,
  142. from_dumpable_context=from_dumpable_context,
  143. )
  144. @deprecated(
  145. "`torch.utils._cxx_pytree._register_pytree_node` is deprecated. "
  146. "Please use `torch.utils._cxx_pytree.register_pytree_node` instead.",
  147. category=FutureWarning,
  148. )
  149. def _register_pytree_node(
  150. cls: Type[Any],
  151. flatten_fn: FlattenFunc,
  152. unflatten_fn: UnflattenFunc,
  153. *,
  154. serialized_type_name: Optional[str] = None,
  155. to_dumpable_context: Optional[ToDumpableContextFn] = None,
  156. from_dumpable_context: Optional[FromDumpableContextFn] = None,
  157. ) -> None:
  158. """Register a container-like type as pytree node for the C++ pytree only.
  159. The ``namespace`` argument is used to avoid collisions that occur when different libraries
  160. register the same Python type with different behaviors. It is recommended to add a unique prefix
  161. to the namespace to avoid conflicts with other libraries. Namespaces can also be used to specify
  162. the same class in different namespaces for different use cases.
  163. .. warning::
  164. For safety reasons, a ``namespace`` must be specified while registering a custom type. It is
  165. used to isolate the behavior of flattening and unflattening a pytree node type. This is to
  166. prevent accidental collisions between different libraries that may register the same type.
  167. Args:
  168. cls (type): A Python type to treat as an internal pytree node.
  169. flatten_fn (callable): A function to be used during flattening, taking an instance of
  170. ``cls`` and returning a pair, with (1) an iterable for the children to be flattened
  171. recursively, and (2) some hashable auxiliary data to be stored in the treespec and to be
  172. passed to the ``unflatten_fn``.
  173. unflatten_fn (callable): A function taking two arguments: the auxiliary data that was
  174. returned by ``flatten_fn`` and stored in the treespec, and the unflattened children.
  175. The function should return an instance of ``cls``.
  176. serialized_type_name (str, optional): A keyword argument used to specify the fully
  177. qualified name used when serializing the tree spec.
  178. to_dumpable_context (callable, optional): An optional keyword argument to custom specify how
  179. to convert the context of the pytree to a custom json dumpable representation. This is
  180. used for json serialization, which is being used in :mod:`torch.export` right now.
  181. from_dumpable_context (callable, optional): An optional keyword argument to custom specify
  182. how to convert the custom json dumpable representation of the context back to the
  183. original context. This is used for json deserialization, which is being used in
  184. :mod:`torch.export` right now.
  185. """
  186. _private_register_pytree_node(
  187. cls,
  188. flatten_fn,
  189. unflatten_fn,
  190. serialized_type_name=serialized_type_name,
  191. to_dumpable_context=to_dumpable_context,
  192. from_dumpable_context=from_dumpable_context,
  193. )
  194. def _private_register_pytree_node(
  195. cls: Type[Any],
  196. flatten_fn: FlattenFunc,
  197. unflatten_fn: UnflattenFunc,
  198. *,
  199. serialized_type_name: Optional[str] = None,
  200. to_dumpable_context: Optional[ToDumpableContextFn] = None,
  201. from_dumpable_context: Optional[FromDumpableContextFn] = None,
  202. ) -> None:
  203. """This is an internal function that is used to register a pytree node type
  204. for the C++ pytree only. End-users should use :func:`register_pytree_node`
  205. instead.
  206. """
  207. # TODO(XuehaiPan): remove this condition when we make Python pytree out-of-box support
  208. # PyStructSequence types
  209. if not optree.is_structseq_class(cls):
  210. optree.register_pytree_node(
  211. cls,
  212. flatten_fn,
  213. _reverse_args(unflatten_fn),
  214. namespace="torch",
  215. )
  216. def tree_flatten(
  217. tree: PyTree,
  218. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  219. ) -> Tuple[List[Any], TreeSpec]:
  220. """Flatten a pytree.
  221. See also :func:`tree_unflatten`.
  222. The flattening order (i.e., the order of elements in the output list) is deterministic,
  223. corresponding to a left-to-right depth-first tree traversal.
  224. >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
  225. >>> tree_flatten(tree)
  226. ([1, 2, 3, 4, None, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf))
  227. >>> tree_flatten(1)
  228. ([1], PyTreeSpec(*, NoneIsLeaf))
  229. >>> tree_flatten(None)
  230. ([None], PyTreeSpec(*, NoneIsLeaf))
  231. For unordered dictionaries, :class:`dict` and :class:`collections.defaultdict`, the order is
  232. dependent on the **sorted** keys in the dictionary. Please use :class:`collections.OrderedDict`
  233. if you want to keep the keys in the insertion order.
  234. >>> from collections import OrderedDict
  235. >>> tree = OrderedDict([('b', (2, [3, 4])), ('a', 1), ('c', None), ('d', 5)])
  236. >>> tree_flatten(tree)
  237. ([2, 3, 4, 1, None, 5], PyTreeSpec(OrderedDict([('b', (*, [*, *])), ('a', *), ('c', *), ('d', *)]), NoneIsLeaf))
  238. Args:
  239. tree (pytree): A pytree to flatten.
  240. is_leaf (callable, optional): An extra leaf predicate function that will be called at each
  241. flattening step. The function should have a single argument with signature
  242. ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
  243. as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
  244. leaf or not. If the function is not specified, the default pytree registry will be used.
  245. Returns:
  246. A pair ``(leaves, treespec)`` where the first element is a list of leaf values and the
  247. second element is a treespec representing the structure of the pytree.
  248. """
  249. return optree.tree_flatten( # type: ignore[return-value]
  250. tree,
  251. is_leaf=is_leaf,
  252. none_is_leaf=True,
  253. namespace="torch",
  254. )
  255. def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
  256. """Reconstruct a pytree from the treespec and the leaves.
  257. The inverse of :func:`tree_flatten`.
  258. >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
  259. >>> leaves, treespec = tree_flatten(tree)
  260. >>> tree == tree_unflatten(leaves, treespec)
  261. True
  262. Args:
  263. leaves (iterable): The list of leaves to use for reconstruction. The list must match the
  264. number of leaves of the treespec.
  265. treespec (TreeSpec): The treespec to reconstruct.
  266. Returns:
  267. The reconstructed pytree, containing the ``leaves`` placed in the structure described by
  268. ``treespec``.
  269. """
  270. if not isinstance(treespec, TreeSpec):
  271. raise TypeError(
  272. f"tree_unflatten(values, spec): Expected `spec` to be instance of "
  273. f"TreeSpec but got item of type {type(treespec)}."
  274. )
  275. return optree.tree_unflatten(treespec, leaves) # type: ignore[arg-type]
  276. def tree_iter(
  277. tree: PyTree,
  278. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  279. ) -> Iterable[Any]:
  280. """Get an iterator over the leaves of a pytree.
  281. See also :func:`tree_flatten`.
  282. >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
  283. >>> list(tree_iter(tree))
  284. [1, 2, 3, 4, None, 5]
  285. >>> list(tree_iter(1))
  286. [1]
  287. >>> list(tree_iter(None))
  288. [None]
  289. Args:
  290. tree (pytree): A pytree to flatten.
  291. is_leaf (callable, optional): An extra leaf predicate function that will be called at each
  292. flattening step. The function should have a single argument with signature
  293. ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
  294. as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
  295. leaf or not. If the function is not specified, the default pytree registry will be used.
  296. Returns:
  297. An iterator over the leaf values.
  298. """
  299. return optree.tree_iter(
  300. tree,
  301. is_leaf=is_leaf,
  302. none_is_leaf=True,
  303. namespace="torch",
  304. )
  305. def tree_leaves(
  306. tree: PyTree,
  307. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  308. ) -> List[Any]:
  309. """Get the leaves of a pytree.
  310. See also :func:`tree_flatten`.
  311. >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
  312. >>> tree_leaves(tree)
  313. [1, 2, 3, 4, None, 5]
  314. >>> tree_leaves(1)
  315. [1]
  316. >>> tree_leaves(None)
  317. [None]
  318. Args:
  319. tree (pytree): A pytree to flatten.
  320. is_leaf (callable, optional): An extra leaf predicate function that will be called at each
  321. flattening step. The function should have a single argument with signature
  322. ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
  323. as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
  324. leaf or not. If the function is not specified, the default pytree registry will be used.
  325. Returns:
  326. A list of leaf values.
  327. """
  328. return optree.tree_leaves(
  329. tree,
  330. is_leaf=is_leaf,
  331. none_is_leaf=True,
  332. namespace="torch",
  333. )
  334. def tree_structure(
  335. tree: PyTree,
  336. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  337. ) -> TreeSpec:
  338. """Get the treespec for a pytree.
  339. See also :func:`tree_flatten`.
  340. >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
  341. >>> tree_structure(tree)
  342. PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf)
  343. >>> tree_structure(1)
  344. PyTreeSpec(*, NoneIsLeaf)
  345. >>> tree_structure(None)
  346. PyTreeSpec(*, NoneIsLeaf)
  347. Args:
  348. tree (pytree): A pytree to flatten.
  349. is_leaf (callable, optional): An extra leaf predicate function that will be called at each
  350. flattening step. The function should have a single argument with signature
  351. ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
  352. as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
  353. leaf or not. If the function is not specified, the default pytree registry will be used.
  354. Returns:
  355. A treespec object representing the structure of the pytree.
  356. """
  357. return optree.tree_structure( # type: ignore[return-value]
  358. tree,
  359. is_leaf=is_leaf,
  360. none_is_leaf=True,
  361. namespace="torch",
  362. )
  363. def tree_map(
  364. func: Callable[..., Any],
  365. tree: PyTree,
  366. *rests: PyTree,
  367. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  368. ) -> PyTree:
  369. """Map a multi-input function over pytree args to produce a new pytree.
  370. See also :func:`tree_map_`.
  371. >>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)})
  372. {'x': 8, 'y': (43, 65)}
  373. >>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None})
  374. {'x': False, 'y': (False, False), 'z': True}
  375. If multiple inputs are given, the structure of the tree is taken from the first input;
  376. subsequent inputs need only have ``tree`` as a prefix:
  377. >>> tree_map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]])
  378. [[5, 7, 9], [6, 1, 2]]
  379. Args:
  380. func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
  381. corresponding leaves of the pytrees.
  382. tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
  383. argument to function ``func``.
  384. rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
  385. ``tree`` or has ``tree`` as a prefix.
  386. is_leaf (callable, optional): An extra leaf predicate function that will be called at each
  387. flattening step. The function should have a single argument with signature
  388. ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
  389. as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
  390. leaf or not. If the function is not specified, the default pytree registry will be used.
  391. Returns:
  392. A new pytree with the same structure as ``tree`` but with the value at each leaf given by
  393. ``func(x, *xs)`` where ``x`` is the value at the corresponding leaf in ``tree`` and ``xs``
  394. is the tuple of values at corresponding nodes in ``rests``.
  395. """
  396. return optree.tree_map(
  397. func,
  398. tree,
  399. *rests,
  400. is_leaf=is_leaf,
  401. none_is_leaf=True,
  402. namespace="torch",
  403. )
  404. def tree_map_(
  405. func: Callable[..., Any],
  406. tree: PyTree,
  407. *rests: PyTree,
  408. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  409. ) -> PyTree:
  410. """Like :func:`tree_map`, but do an inplace call on each leaf and return the original tree.
  411. See also :func:`tree_map`.
  412. Args:
  413. func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
  414. corresponding leaves of the pytrees.
  415. tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
  416. argument to function ``func``.
  417. rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
  418. ``tree`` or has ``tree`` as a prefix.
  419. is_leaf (callable, optional): An extra leaf predicate function that will be called at each
  420. flattening step. The function should have a single argument with signature
  421. ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
  422. as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
  423. leaf or not. If the function is not specified, the default pytree registry will be used.
  424. Returns:
  425. The original ``tree`` with the value at each leaf is given by the side-effect of function
  426. ``func(x, *xs)`` (not the return value) where ``x`` is the value at the corresponding leaf
  427. in ``tree`` and ``xs`` is the tuple of values at values at corresponding nodes in ``rests``.
  428. """
  429. return optree.tree_map_(
  430. func,
  431. tree,
  432. *rests,
  433. is_leaf=is_leaf,
  434. none_is_leaf=True,
  435. namespace="torch",
  436. )
# Tuples of two / three concrete types, used to type the ``*_only`` overloads.
Type2 = Tuple[Type[T], Type[S]]
Type3 = Tuple[Type[T], Type[S], Type[U]]
# Anything usable as the second argument of ``isinstance``. ``types.UnionType`` is only
# referenced on Python 3.10+, hence the version gate.
if sys.version_info >= (3, 10):
    TypeAny = Union[Type[Any], Tuple[Type[Any], ...], types.UnionType]
else:
    TypeAny = Union[Type[Any], Tuple[Type[Any], ...]]

# Callables taking a value of one of two / three types, exactly one type, or anything.
Fn2 = Callable[[Union[T, S]], R]
Fn3 = Callable[[Union[T, S, U]], R]
Fn = Callable[[T], R]
FnAny = Callable[[Any], R]

# Type of the decorator returned by ``map_only``: takes a function, returns the wrapped one.
MapOnlyFn = Callable[[T], Callable[[Any], Any]]
  448. # These specializations help with type inference on the lambda passed to this
  449. # function
  450. @overload
  451. def map_only(__type_or_types_or_pred: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]:
  452. ...
  453. @overload
  454. def map_only(__type_or_types_or_pred: Type3[T, S, U]) -> MapOnlyFn[Fn3[T, S, U, Any]]:
  455. ...
  456. @overload
  457. def map_only(__type_or_types_or_pred: Type[T]) -> MapOnlyFn[Fn[T, Any]]:
  458. ...
  459. # This specialization is needed for the implementations below that call
  460. @overload
  461. def map_only(__type_or_types_or_pred: TypeAny) -> MapOnlyFn[FnAny[Any]]:
  462. ...
  463. @overload
  464. def map_only(__type_or_types_or_pred: Callable[[Any], bool]) -> MapOnlyFn[FnAny[Any]]:
  465. ...
  466. def map_only(
  467. __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]]
  468. ) -> MapOnlyFn[FnAny[Any]]:
  469. """
  470. Suppose you are writing a tree_map over tensors, leaving everything
  471. else unchanged. Ordinarily you would have to write:
  472. def go(t):
  473. if isinstance(t, Tensor):
  474. return ...
  475. else:
  476. return t
  477. With this function, you only need to write:
  478. @map_only(Tensor)
  479. def go(t):
  480. return ...
  481. You can also directly use 'tree_map_only'
  482. """
  483. if isinstance(__type_or_types_or_pred, (type, tuple)) or (
  484. sys.version_info >= (3, 10)
  485. and isinstance(__type_or_types_or_pred, types.UnionType)
  486. ):
  487. def pred(x: Any) -> bool:
  488. return isinstance(x, __type_or_types_or_pred) # type: ignore[arg-type]
  489. elif callable(__type_or_types_or_pred):
  490. pred = __type_or_types_or_pred # type: ignore[assignment]
  491. else:
  492. raise TypeError("Argument must be a type, a tuple of types, or a callable.")
  493. def wrapper(func: Callable[[T], Any]) -> Callable[[Any], Any]:
  494. @functools.wraps(func)
  495. def wrapped(x: T) -> Any:
  496. if pred(x):
  497. return func(x)
  498. return x
  499. return wrapped
  500. return wrapper
  501. @overload
  502. def tree_map_only(
  503. __type_or_types_or_pred: Type[T],
  504. func: Fn[T, Any],
  505. tree: PyTree,
  506. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  507. ) -> PyTree:
  508. ...
  509. @overload
  510. def tree_map_only(
  511. __type_or_types_or_pred: Type2[T, S],
  512. func: Fn2[T, S, Any],
  513. tree: PyTree,
  514. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  515. ) -> PyTree:
  516. ...
  517. @overload
  518. def tree_map_only(
  519. __type_or_types_or_pred: Type3[T, S, U],
  520. func: Fn3[T, S, U, Any],
  521. tree: PyTree,
  522. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  523. ) -> PyTree:
  524. ...
  525. @overload
  526. def tree_map_only(
  527. __type_or_types_or_pred: Callable[[Any], bool],
  528. func: FnAny[Any],
  529. tree: PyTree,
  530. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  531. ) -> PyTree:
  532. ...
  533. def tree_map_only(
  534. __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
  535. func: FnAny[Any],
  536. tree: PyTree,
  537. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  538. ) -> PyTree:
  539. return tree_map(map_only(__type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
  540. @overload
  541. def tree_map_only_(
  542. __type_or_types_or_pred: Type[T],
  543. func: Fn[T, Any],
  544. tree: PyTree,
  545. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  546. ) -> PyTree:
  547. ...
  548. @overload
  549. def tree_map_only_(
  550. __type_or_types_or_pred: Type2[T, S],
  551. func: Fn2[T, S, Any],
  552. tree: PyTree,
  553. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  554. ) -> PyTree:
  555. ...
  556. @overload
  557. def tree_map_only_(
  558. __type_or_types_or_pred: Type3[T, S, U],
  559. func: Fn3[T, S, U, Any],
  560. tree: PyTree,
  561. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  562. ) -> PyTree:
  563. ...
  564. @overload
  565. def tree_map_only_(
  566. __type_or_types_or_pred: Callable[[Any], bool],
  567. func: FnAny[Any],
  568. tree: PyTree,
  569. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  570. ) -> PyTree:
  571. ...
  572. def tree_map_only_(
  573. __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
  574. func: FnAny[Any],
  575. tree: PyTree,
  576. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  577. ) -> PyTree:
  578. return tree_map_(map_only(__type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
  579. def tree_all(
  580. pred: Callable[[Any], bool],
  581. tree: PyTree,
  582. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  583. ) -> bool:
  584. flat_args = tree_iter(tree, is_leaf=is_leaf)
  585. return all(map(pred, flat_args))
  586. def tree_any(
  587. pred: Callable[[Any], bool],
  588. tree: PyTree,
  589. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  590. ) -> bool:
  591. flat_args = tree_iter(tree, is_leaf=is_leaf)
  592. return any(map(pred, flat_args))
  593. @overload
  594. def tree_all_only(
  595. __type_or_types: Type[T],
  596. pred: Fn[T, bool],
  597. tree: PyTree,
  598. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  599. ) -> bool:
  600. ...
  601. @overload
  602. def tree_all_only(
  603. __type_or_types: Type2[T, S],
  604. pred: Fn2[T, S, bool],
  605. tree: PyTree,
  606. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  607. ) -> bool:
  608. ...
  609. @overload
  610. def tree_all_only(
  611. __type_or_types: Type3[T, S, U],
  612. pred: Fn3[T, S, U, bool],
  613. tree: PyTree,
  614. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  615. ) -> bool:
  616. ...
  617. def tree_all_only(
  618. __type_or_types: TypeAny,
  619. pred: FnAny[bool],
  620. tree: PyTree,
  621. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  622. ) -> bool:
  623. flat_args = tree_iter(tree, is_leaf=is_leaf)
  624. return all(pred(x) for x in flat_args if isinstance(x, __type_or_types))
  625. @overload
  626. def tree_any_only(
  627. __type_or_types: Type[T],
  628. pred: Fn[T, bool],
  629. tree: PyTree,
  630. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  631. ) -> bool:
  632. ...
  633. @overload
  634. def tree_any_only(
  635. __type_or_types: Type2[T, S],
  636. pred: Fn2[T, S, bool],
  637. tree: PyTree,
  638. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  639. ) -> bool:
  640. ...
  641. @overload
  642. def tree_any_only(
  643. __type_or_types: Type3[T, S, U],
  644. pred: Fn3[T, S, U, bool],
  645. tree: PyTree,
  646. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  647. ) -> bool:
  648. ...
  649. def tree_any_only(
  650. __type_or_types: TypeAny,
  651. pred: FnAny[bool],
  652. tree: PyTree,
  653. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  654. ) -> bool:
  655. flat_args = tree_iter(tree, is_leaf=is_leaf)
  656. return any(pred(x) for x in flat_args if isinstance(x, __type_or_types))
  657. def broadcast_prefix(
  658. prefix_tree: PyTree,
  659. full_tree: PyTree,
  660. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  661. ) -> List[Any]:
  662. """Return a list of broadcasted leaves in ``prefix_tree`` to match the number of leaves in ``full_tree``.
  663. If a ``prefix_tree`` is a prefix of a ``full_tree``, this means the ``full_tree`` can be
  664. constructed by replacing the leaves of ``prefix_tree`` with appropriate **subtrees**.
  665. This function returns a list of leaves with the same size as ``full_tree``. The leaves are
  666. replicated from ``prefix_tree``. The number of replicas is determined by the corresponding
  667. subtree in ``full_tree``.
  668. >>> broadcast_prefix(1, [1, 2, 3])
  669. [1, 1, 1]
  670. >>> broadcast_prefix([1, 2, 3], [1, 2, 3])
  671. [1, 2, 3]
  672. >>> broadcast_prefix([1, 2, 3], [1, 2, 3, 4])
  673. Traceback (most recent call last):
  674. ...
  675. ValueError: list arity mismatch; expected: 3, got: 4; list: [1, 2, 3, 4].
  676. >>> broadcast_prefix([1, 2, 3], [1, 2, (3, 4)])
  677. [1, 2, 3, 3]
  678. >>> broadcast_prefix([1, 2, 3], [1, 2, {'a': 3, 'b': 4, 'c': (None, 5)}])
  679. [1, 2, 3, 3, 3, 3]
  680. Args:
  681. prefix_tree (pytree): A pytree with the same structure as a prefix of ``full_tree``.
  682. full_tree (pytree): A pytree with the same structure as a suffix of ``prefix_tree``.
  683. is_leaf (callable, optional): An extra leaf predicate function that will be called at each
  684. flattening step. The function should have a single argument with signature
  685. ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
  686. as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
  687. leaf or not. If the function is not specified, the default pytree registry will be used.
  688. Returns:
  689. A list of leaves in ``prefix_tree`` broadcasted to match the number of leaves in ``full_tree``.
  690. """
  691. return optree.broadcast_prefix(
  692. prefix_tree,
  693. full_tree,
  694. is_leaf=is_leaf,
  695. none_is_leaf=True,
  696. namespace="torch",
  697. )
  698. # Broadcasts a pytree to the provided TreeSpec and returns the flattened
  699. # values. If this is not possible, then this function returns None.
  700. #
  701. # For example, given pytree=0 and spec=TreeSpec(list, None, [LeafSpec(), LeafSpec()]),
  702. # would return [0, 0]. This is useful for part of the vmap implementation:
  703. # a user can pass in vmap(fn, in_dims)(*inputs). `in_dims` should be
  704. # broadcastable to the tree structure of `inputs` and we use
  705. # _broadcast_to_and_flatten to check this.
  706. def _broadcast_to_and_flatten(
  707. tree: PyTree,
  708. treespec: TreeSpec,
  709. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  710. ) -> Optional[List[Any]]:
  711. assert isinstance(treespec, TreeSpec)
  712. full_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
  713. try:
  714. return broadcast_prefix(tree, full_tree, is_leaf=is_leaf)
  715. except ValueError:
  716. return None
  717. def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
  718. """Serialize a treespec to a JSON string."""
  719. if not isinstance(treespec, TreeSpec):
  720. raise TypeError(
  721. f"treespec_dumps(spec): Expected `spec` to be instance of "
  722. f"TreeSpec but got item of type {type(treespec)}."
  723. )
  724. from ._pytree import (
  725. tree_structure as _tree_structure,
  726. treespec_dumps as _treespec_dumps,
  727. )
  728. orig_treespec = _tree_structure(tree_unflatten([0] * treespec.num_leaves, treespec))
  729. return _treespec_dumps(orig_treespec, protocol=protocol)
  730. def treespec_loads(serialized: str) -> TreeSpec:
  731. """Deserialize a treespec from a JSON string."""
  732. from ._pytree import (
  733. tree_unflatten as _tree_unflatten,
  734. treespec_loads as _treespec_loads,
  735. )
  736. orig_treespec = _treespec_loads(serialized)
  737. dummy_tree = _tree_unflatten([0] * orig_treespec.num_leaves, orig_treespec)
  738. treespec = tree_structure(dummy_tree)
  739. return treespec
  740. class _DummyLeaf:
  741. def __repr__(self) -> str:
  742. return "*"
  743. def treespec_pprint(treespec: TreeSpec) -> str:
  744. dummy_tree = tree_unflatten(
  745. [_DummyLeaf() for _ in range(treespec.num_leaves)],
  746. treespec,
  747. )
  748. return repr(dummy_tree)
  749. class LeafSpecMeta(type(TreeSpec)): # type: ignore[misc]
  750. def __instancecheck__(self, instance: object) -> bool:
  751. return isinstance(instance, TreeSpec) and instance.is_leaf()
  752. class LeafSpec(TreeSpec, metaclass=LeafSpecMeta):
  753. def __new__(cls) -> "LeafSpec":
  754. return optree.treespec_leaf(none_is_leaf=True) # type: ignore[return-value]
  755. def tree_flatten_with_path(
  756. tree: PyTree,
  757. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  758. ) -> Tuple[List[Tuple[KeyPath, Any]], TreeSpec]:
  759. """Flattens a pytree like :func:`tree_flatten`, but also returns each leaf's key path.
  760. Args:
  761. tree: a pytree to flatten. If it contains a custom type, that type must be
  762. registered with an appropriate `tree_flatten_with_path_fn` when registered
  763. with :func:`register_pytree_node`.
  764. is_leaf: An extra leaf predicate function that will be called at each
  765. flattening step. The function should have a single argument with signature
  766. ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
  767. as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
  768. leaf or not. If the function is not specified, the default pytree registry will be used.
  769. Returns:
  770. A tuple where the first element is a list of (key path, leaf) pairs, and the
  771. second element is a :class:`TreeSpec` representing the structure of the flattened
  772. tree.
  773. """
  774. raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")
  775. def tree_leaves_with_path(
  776. tree: PyTree,
  777. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  778. ) -> List[Tuple[KeyPath, Any]]:
  779. """Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path.
  780. Args:
  781. tree: a pytree. If it contains a custom type, that type must be
  782. registered with an appropriate `tree_flatten_with_path_fn` when registered
  783. with :func:`register_pytree_node`.
  784. is_leaf: An extra leaf predicate function that will be called at each
  785. flattening step. The function should have a single argument with signature
  786. ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
  787. as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
  788. leaf or not. If the function is not specified, the default pytree registry will be used.
  789. Returns:
  790. A list of (key path, leaf) pairs.
  791. """
  792. raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")
  793. def tree_map_with_path(
  794. func: Callable[..., Any],
  795. tree: PyTree,
  796. *rests: PyTree,
  797. is_leaf: Optional[Callable[[PyTree], bool]] = None,
  798. ) -> PyTree:
  799. """Like :func:`tree_map`, but the provided callable takes an additional key path argument.
  800. Args:
  801. func: A function that takes ``2 + len(rests)`` arguments, to be applied at the
  802. corresponding leaves of the pytrees. The first positional argument
  803. to ``func`` is the key path of the leaf in question. The second
  804. positional argument is the value of the leaf.
  805. tree: A pytree to be mapped over, with each leaf providing the first positional
  806. argument to function ``func``.
  807. rests: A tuple of pytrees, each of which has the same structure as
  808. ``tree`` or has ``tree`` as a prefix.
  809. is_leaf: An extra leaf predicate function that will be called at each
  810. flattening step. The function should have a single argument with signature
  811. ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated
  812. as a leaf. Otherwise, the default pytree registry will be used to determine a node is a
  813. leaf or not. If the function is not specified, the default pytree registry will be used.
  814. Returns
  815. A new pytree with the same structure as ``tree`` but with the value at each leaf given by
  816. ``func(keypath, x, *xs)`` where ``keypath`` is the key path at the
  817. corresponding leaf in ``tree``, ``x`` is the value at that leaf, and
  818. ``xs`` is the tuple of values at corresponding nodes in ``rests``.
  819. """
  820. raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")
  821. def keystr(kp: KeyPath) -> str:
  822. """Given a key path, return a pretty-printed representation."""
  823. raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")
  824. def key_get(obj: Any, kp: KeyPath) -> Any:
  825. """Given an object and a key path, return the value at the key path."""
  826. raise NotImplementedError("KeyPaths are not yet supported in cxx_pytree.")