dim.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. # Copyright (c) Facebook, Inc. and its affiliates.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the BSD-style license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. import dis
  7. import inspect
  8. from dataclasses import dataclass
  9. from typing import Union
  10. from . import DimList
  11. _vmap_levels = []
  12. @dataclass
  13. class LevelInfo:
  14. level: int
  15. alive: bool = True
  16. class Dim:
  17. def __init__(self, name: str, size: Union[None, int] = None):
  18. self.name = name
  19. self._size = None
  20. self._vmap_level = None
  21. if size is not None:
  22. self.size = size
  23. def __del__(self):
  24. if self._vmap_level is not None:
  25. _vmap_active_levels[self._vmap_stack].alive = False # noqa: F821
  26. while (
  27. not _vmap_levels[-1].alive
  28. and current_level() == _vmap_levels[-1].level # noqa: F821
  29. ):
  30. _vmap_decrement_nesting() # noqa: F821
  31. _vmap_levels.pop()
  32. @property
  33. def size(self):
  34. assert self.is_bound
  35. return self._size
  36. @size.setter
  37. def size(self, size: int):
  38. from . import DimensionBindError
  39. if self._size is None:
  40. self._size = size
  41. self._vmap_level = _vmap_increment_nesting(size, "same") # noqa: F821
  42. self._vmap_stack = len(_vmap_levels)
  43. _vmap_levels.append(LevelInfo(self._vmap_level))
  44. elif self._size != size:
  45. raise DimensionBindError(
  46. f"Dim '{self}' previously bound to a dimension of size {self._size} cannot bind to a dimension of size {size}"
  47. )
  48. @property
  49. def is_bound(self):
  50. return self._size is not None
  51. def __repr__(self):
  52. return self.name
  53. def extract_name(inst):
  54. assert inst.opname == "STORE_FAST" or inst.opname == "STORE_NAME"
  55. return inst.argval
  56. _cache = {}
  57. def dims(lists=0):
  58. frame = inspect.currentframe()
  59. assert frame is not None
  60. calling_frame = frame.f_back
  61. assert calling_frame is not None
  62. code, lasti = calling_frame.f_code, calling_frame.f_lasti
  63. key = (code, lasti)
  64. if key not in _cache:
  65. first = lasti // 2 + 1
  66. instructions = list(dis.get_instructions(calling_frame.f_code))
  67. unpack = instructions[first]
  68. if unpack.opname == "STORE_FAST" or unpack.opname == "STORE_NAME":
  69. # just a single dim, not a list
  70. name = unpack.argval
  71. ctor = Dim if lists == 0 else DimList
  72. _cache[key] = lambda: ctor(name=name)
  73. else:
  74. assert unpack.opname == "UNPACK_SEQUENCE"
  75. ndims = unpack.argval
  76. names = tuple(
  77. extract_name(instructions[first + 1 + i]) for i in range(ndims)
  78. )
  79. first_list = len(names) - lists
  80. _cache[key] = lambda: tuple(
  81. Dim(n) if i < first_list else DimList(name=n)
  82. for i, n in enumerate(names)
  83. )
  84. return _cache[key]()
  85. def _dim_set(positional, arg):
  86. def convert(a):
  87. if isinstance(a, Dim):
  88. return a
  89. else:
  90. assert isinstance(a, int)
  91. return positional[a]
  92. if arg is None:
  93. return positional
  94. elif not isinstance(arg, (Dim, int)):
  95. return tuple(convert(a) for a in arg)
  96. else:
  97. return (convert(arg),)