| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121 |
- # Copyright (c) Facebook, Inc. and its affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the BSD-style license found in the
- # LICENSE file in the root directory of this source tree.
- import dis
- import inspect
- from dataclasses import dataclass
- from typing import Union
- from . import DimList
- _vmap_levels = []
@dataclass
class LevelInfo:
    """Bookkeeping record for one vmap nesting level owned by a bound Dim.

    Holds the level number that was entered and a flag telling whether the
    owner is still alive; dead entries can be unwound from the level stack.
    """

    # Level number entered when the owning Dim was bound.
    level: int
    # Cleared once the owning Dim has been destroyed.
    alive: bool = True
class Dim:
    """A named first-class dimension, optionally bound to a concrete size.

    Binding a size (through the constructor or by assigning ``size``) enters
    a new vmap nesting level via ``_vmap_increment_nesting`` and pushes a
    ``LevelInfo`` record onto the module-level ``_vmap_levels`` stack.  When
    the Dim is destroyed its record is marked dead, and dead records on top
    of the stack are popped (decrementing vmap nesting) while the top record
    matches ``current_level()``.
    """

    def __init__(self, name: str, size: Union[None, int] = None):
        self.name = name
        self._size = None
        self._vmap_level = None
        # Index of this Dim's record in ``_vmap_levels``; set when bound.
        self._vmap_stack = None
        if size is not None:
            self.size = size

    def __del__(self):
        # ``getattr`` tolerates instances whose ``__init__`` never ran.
        if getattr(self, "_vmap_level", None) is None:
            return
        # Mark our own level dead; it may not be on top of the stack yet if
        # Dims are released out of order.
        _vmap_levels[self._vmap_stack].alive = False
        # Unwind every dead level sitting on top of the stack.  The emptiness
        # check matters: after popping the last record, indexing
        # ``_vmap_levels[-1]`` would raise IndexError.
        while (
            _vmap_levels
            and not _vmap_levels[-1].alive
            and current_level() == _vmap_levels[-1].level  # noqa: F821
        ):
            _vmap_decrement_nesting()  # noqa: F821
            _vmap_levels.pop()

    @property
    def size(self):
        """The size this Dim is bound to; asserts that it is bound."""
        assert self.is_bound, f"Dim '{self}' is not bound to a size"
        return self._size

    @size.setter
    def size(self, size: int):
        # Binding is one-shot: the first size wins, re-binding to the same
        # size is a no-op, and a different size is an error.
        if self._size is None:
            # Enter the vmap level before recording anything, so a failure
            # here leaves the Dim cleanly unbound.
            level = _vmap_increment_nesting(size, "same")  # noqa: F821
            self._size = size
            self._vmap_level = level
            self._vmap_stack = len(_vmap_levels)
            _vmap_levels.append(LevelInfo(level))
        elif self._size != size:
            from . import DimensionBindError

            raise DimensionBindError(
                f"Dim '{self}' previously bound to a dimension of size {self._size} cannot bind to a dimension of size {size}"
            )

    @property
    def is_bound(self):
        """Whether a size has been bound to this Dim."""
        return self._size is not None

    def __repr__(self):
        return self.name
def extract_name(inst):
    """Return the variable name written by a single-name store instruction.

    Accepts the ``dis.Instruction`` store opcodes that bind exactly one name
    (``STORE_FAST``, ``STORE_NAME``, ``STORE_GLOBAL``, ``STORE_DEREF``); for
    each of them ``dis`` reports the target variable name as ``argval``.

    Raises:
        AssertionError: if ``inst`` is not one of those store opcodes.
    """
    assert inst.opname in (
        "STORE_FAST",
        "STORE_NAME",
        "STORE_GLOBAL",
        "STORE_DEREF",
    ), f"expected a single-name store instruction, got {inst.opname}"
    return inst.argval
- _cache = {}
def dims(lists=0):
    """Create dimension objects named after the variables they are assigned to.

    The calling frame's bytecode is inspected to find the assignment
    target(s) of this call, so ``i = dims()`` yields a ``Dim`` named ``i``
    and ``i, j = dims()`` yields a tuple of ``Dim`` objects named ``i`` and
    ``j``.  Decoded target names are cached per call site.

    Args:
        lists: Number of trailing unpacked targets to create as ``DimList``
            instead of ``Dim``.  For a single (non-unpacked) target, any
            non-zero value creates a ``DimList``.

    Returns:
        A single ``Dim``/``DimList`` for a plain assignment, otherwise a
        tuple with one object per unpacked target.

    Raises:
        AssertionError: if the call result is not stored directly into a
            name or unpacked into names.
    """
    frame = inspect.currentframe()
    assert frame is not None
    try:
        calling_frame = frame.f_back
        assert calling_frame is not None
        code, lasti = calling_frame.f_code, calling_frame.f_lasti
        del calling_frame
    finally:
        # Frame references kept in locals create reference cycles; the
        # ``inspect`` docs recommend deleting them explicitly.
        del frame

    # ``lists`` is part of the key: one call site may be invoked with
    # different ``lists`` values (e.g. ``dims(lists=n)`` with varying ``n``).
    key = (code, lasti, lists)
    factory = _cache.get(key)
    if factory is None:
        store_ops = ("STORE_FAST", "STORE_NAME", "STORE_GLOBAL", "STORE_DEREF")
        # Instructions executed after this call returns.  Selecting by byte
        # offset (rather than indexing with ``lasti // 2 + 1``) stays correct
        # when ``dis`` hides inline CACHE entries (CPython 3.11+) and when an
        # EXTENDED_ARG prefix precedes the store.
        following = [
            ins
            for ins in dis.get_instructions(code)
            if ins.offset > lasti and ins.opname != "EXTENDED_ARG"
        ]
        target = following[0]
        if target.opname in store_ops:
            # Plain assignment: ``x = dims()``.
            name = target.argval
            ctor = Dim if lists == 0 else DimList

            def factory():
                return ctor(name=name)

        else:
            # Tuple unpacking: ``x, y, ... = dims()``.
            assert target.opname == "UNPACK_SEQUENCE"
            ndims = target.argval
            # NOTE(review): newer CPython compilers may fuse adjacent stores
            # into superinstructions (e.g. STORE_FAST_STORE_FAST), which
            # ``extract_name`` rejects -- confirm on the targeted versions.
            names = tuple(extract_name(ins) for ins in following[1 : 1 + ndims])
            first_list = len(names) - lists

            def factory():
                return tuple(
                    Dim(n) if i < first_list else DimList(name=n)
                    for i, n in enumerate(names)
                )

        _cache[key] = factory
    return factory()
- def _dim_set(positional, arg):
- def convert(a):
- if isinstance(a, Dim):
- return a
- else:
- assert isinstance(a, int)
- return positional[a]
- if arg is None:
- return positional
- elif not isinstance(arg, (Dim, int)):
- return tuple(convert(a) for a in arg)
- else:
- return (convert(arg),)
|