immutable_collections.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. # mypy: allow-untyped-defs
  2. from typing import Any, Dict, Iterable, List, Tuple
  3. from torch.utils._pytree import (
  4. _dict_flatten,
  5. _dict_flatten_with_keys,
  6. _dict_unflatten,
  7. _list_flatten,
  8. _list_flatten_with_keys,
  9. _list_unflatten,
  10. Context,
  11. register_pytree_node,
  12. )
  13. from ._compatibility import compatibility
  14. __all__ = ["immutable_list", "immutable_dict"]
  15. _help_mutation = """\
  16. If you are attempting to modify the kwargs or args of a torch.fx.Node object,
  17. instead create a new copy of it and assign the copy to the node:
  18. new_args = ... # copy and mutate args
  19. node.args = new_args
  20. """
  21. def _no_mutation(self, *args, **kwargs):
  22. raise NotImplementedError(
  23. f"'{type(self).__name__}' object does not support mutation. {_help_mutation}",
  24. )
  25. def _create_immutable_container(base, mutable_functions):
  26. container = type("immutable_" + base.__name__, (base,), {})
  27. for attr in mutable_functions:
  28. setattr(container, attr, _no_mutation)
  29. return container
  30. immutable_list = _create_immutable_container(
  31. list,
  32. (
  33. "__delitem__",
  34. "__iadd__",
  35. "__imul__",
  36. "__setitem__",
  37. "append",
  38. "clear",
  39. "extend",
  40. "insert",
  41. "pop",
  42. "remove",
  43. "reverse",
  44. "sort",
  45. ),
  46. )
  47. immutable_list.__reduce__ = lambda self: (immutable_list, (tuple(iter(self)),))
  48. immutable_list.__hash__ = lambda self: hash(tuple(self))
  49. compatibility(is_backward_compatible=True)(immutable_list)
  50. immutable_dict = _create_immutable_container(
  51. dict,
  52. (
  53. "__delitem__",
  54. "__ior__",
  55. "__setitem__",
  56. "clear",
  57. "pop",
  58. "popitem",
  59. "setdefault",
  60. "update",
  61. ),
  62. )
  63. immutable_dict.__reduce__ = lambda self: (immutable_dict, (iter(self.items()),))
  64. immutable_dict.__hash__ = lambda self: hash(tuple(self.items()))
  65. compatibility(is_backward_compatible=True)(immutable_dict)
  66. # Register immutable collections for PyTree operations
  67. def _immutable_dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]:
  68. return _dict_flatten(d)
  69. def _immutable_dict_unflatten(
  70. values: Iterable[Any],
  71. context: Context,
  72. ) -> Dict[Any, Any]:
  73. return immutable_dict(_dict_unflatten(values, context))
  74. def _immutable_list_flatten(d: List[Any]) -> Tuple[List[Any], Context]:
  75. return _list_flatten(d)
  76. def _immutable_list_unflatten(
  77. values: Iterable[Any],
  78. context: Context,
  79. ) -> List[Any]:
  80. return immutable_list(_list_unflatten(values, context))
  81. register_pytree_node(
  82. immutable_dict,
  83. _immutable_dict_flatten,
  84. _immutable_dict_unflatten,
  85. serialized_type_name="torch.fx.immutable_collections.immutable_dict",
  86. flatten_with_keys_fn=_dict_flatten_with_keys,
  87. )
  88. register_pytree_node(
  89. immutable_list,
  90. _immutable_list_flatten,
  91. _immutable_list_unflatten,
  92. serialized_type_name="torch.fx.immutable_collections.immutable_list",
  93. flatten_with_keys_fn=_list_flatten_with_keys,
  94. )