- # mypy: allow-untyped-defs
- from __future__ import annotations
- import contextlib
- import dataclasses
- import warnings
- import weakref
- from dataclasses import dataclass
- from typing import (
- Any,
- Callable,
- ClassVar,
- ContextManager,
- Dict,
- List,
- Optional,
- Tuple,
- Type,
- TYPE_CHECKING,
- Union,
- )
- from typing_extensions import TypeAlias
- import torch
- from torch._C._autograd import CreationMeta
- from torch._C._functorch import (
- _add_batch_dim,
- _unwrap_functional_tensor,
- _wrap_functional_tensor,
- get_unwrapped,
- is_batchedtensor,
- is_functorch_wrapped_tensor,
- is_gradtrackingtensor,
- is_legacy_batchedtensor,
- maybe_get_bdim,
- maybe_get_level,
- peek_interpreter_stack,
- )
- from torch._logging import trace_structured
- from torch.utils._mode_utils import no_dispatch
- from torch.utils._python_dispatch import is_traceable_wrapper_subclass
- from torch.utils.weak import WeakIdKeyDictionary
- if TYPE_CHECKING:
- from torch._C._functorch import CInterpreter
- from torch._guards import Source
- # Import here to avoid cycle
- from torch._subclasses.fake_tensor import FakeTensorMode
- # Import the following modules during type checking to enable code intelligence features.
- # Do not import unconditionally, as they import sympy and importing sympy is very slow.
- from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymbolicContext
- DimList = List
- def safe_is_leaf(t):
- try:
- return t.is_leaf
- except RuntimeError:
- # inference mode can trigger this
- return False
- def safe_grad(t):
- with warnings.catch_warnings():
- warnings.filterwarnings("ignore", "The .grad attribute of a Tensor")
- return t.grad
- def assert_eq(a, b):
- assert a == b, f"{a} != {b}"
- def assert_metadata_eq(
- assert_eq,
- m1: Union[MetaTensorDesc, torch.Tensor],
- m2: torch.Tensor,
- *,
- skip_symbolic=False,
- skip_leaf=False,
- ):
- if isinstance(m1, torch.Tensor):
- m1 = MetaTensorDescriber().describe_tensor(m1)
- def go(m1, m2):
- assert_eq(m1.dtype, m2.dtype)
- if not skip_symbolic:
- assert_eq(m1.shape, m2.shape)
- assert_eq(m1.requires_grad, m2.requires_grad)
- if not skip_leaf:
- assert_eq(m1.is_leaf, m2.is_leaf)
- # MetaTensorDesc doesn't store grad_fn; inferred from leaf
- # assert_eq(m1.grad_fn is None, m2.grad_fn is None)
- assert_eq(m1.is_sparse, m2.is_sparse)
- assert_eq(m1.is_inference, m2.is_inference())
- assert_eq(m1.is_conj, m2.is_conj())
- assert_eq(m1.is_neg, m2.is_neg())
- assert_eq(m1.grad is not None, safe_grad(m2) is not None)
- if m1.grad is not None:
- go(m1.grad, safe_grad(m2))
- if m1.is_sparse:
- assert_eq(m1.dense_dim, m2.dense_dim())
- assert_eq(m1.sparse_dim, m2.sparse_dim())
- assert_eq(m1.is_coalesced, m2.is_coalesced())
- else:
- if not skip_symbolic:
- assert_eq(m1.stride, m2.stride())
- assert_eq(m1.storage_offset, m2.storage_offset())
- assert_eq(m1.is_view, m2._is_view())
- if m1.is_view:
- go(m1.base, m2._base)
- # TODO: test if is resizable (no direct query for this atm)
- # TODO: audit AutogradMeta to see if it matches
- # TODO: test forward AD
- return go(m1, m2)
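- # Illustrative usage sketch (hypothetical, not part of the original file):
- # comparing a real tensor against its meta conversion. Pass skip_symbolic=True
- # when the meta side carries SymInt sizes/strides that shouldn't be compared
- # directly.
- #
- #   real = torch.randn(2, 3, requires_grad=True)
- #   meta = MetaConverter()(real)
- #   assert_metadata_eq(assert_eq, real, meta)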
- def is_sparse_coo(t):
- return isinstance(t, torch.Tensor) and t.layout is torch.sparse_coo
- def is_sparse_compressed_layout(layout):
- return layout in {
- torch.sparse_csr,
- torch.sparse_csc,
- torch.sparse_bsr,
- torch.sparse_bsc,
- }
- def is_sparse_compressed(t):
- return isinstance(t, torch.Tensor) and is_sparse_compressed_layout(t.layout)
- def is_sparse_any(t):
- return is_sparse_coo(t) or is_sparse_compressed(t)
- # Don't use id() directly, because those can get reallocated over time.
- MetaStorageId: TypeAlias = int
- MetaTensorId: TypeAlias = int
- DESCRIBER_NEXT_ID = 0
- class MetaTensorDescriber:
- """
- Given a Tensor/Storage, generate a MetaTensorDesc/MetaStorageDesc
- for it, which is enough information to reconstruct a meta tensor/fake tensor
- corresponding to a Tensor as faithfully as possible.
- This is a stateful conversion object because we keep track of the IDs
- of the tensors/storages passed to us, so we can consistently give
- the same ID when we see the same tensor/storage.
- """
- def __init__(self, *, copy_data=False):
- global DESCRIBER_NEXT_ID
- self.id = DESCRIBER_NEXT_ID
- DESCRIBER_NEXT_ID += 1
- self.next_tensor_id: MetaTensorId = 0
- self.next_storage_id: MetaStorageId = 0
- # Tensor -> int
- self.lookup_tensor = WeakIdKeyDictionary()
- # Storage -> int
- self.lookup_storage = WeakIdKeyDictionary()
- self.copy_data = copy_data
- self.traced_tensors = set()
- self.traced_storages = set()
- def get_tensor_id(self, t: torch.Tensor):
- if t not in self.lookup_tensor:
- self.lookup_tensor[t] = self.next_tensor_id
- self.next_tensor_id += 1
- return self.lookup_tensor[t]
- def get_storage_id(self, s: torch.UntypedStorage):
- if s not in self.lookup_storage:
- self.lookup_storage[s] = self.next_storage_id
- self.next_storage_id += 1
- return self.lookup_storage[s]
- def describe_storage(self, s: torch.UntypedStorage, *, trace: bool = False):
- r = MetaStorageDesc(
- id=self.get_storage_id(s),
- size=s.size(),
- # NB: We don't do the copy yet; copy happens when we start
- # creating the new storages
- data=s if self.copy_data else None,
- )
- if trace and r.id not in self.traced_storages:
- trace_structured(
- "describe_storage",
- metadata_fn=lambda: r.as_json(self.id),
- )
- self.traced_storages.add(r.id)
- return r
- def describe_tensor(
- self, t: torch.Tensor, *, recurse: bool = True, trace: bool = False
- ):
- is_leaf = safe_is_leaf(t)
- is_view = t._is_view()
- is_sparse = t.is_sparse
- layout = t.layout
- is_nested = t.is_nested
- is_traceable_wrapper_subclass_v = is_traceable_wrapper_subclass(t)
- is_functorch_wrapped = is_functorch_wrapped_tensor(t)
- is_mkldnn = t.is_mkldnn
- is_batchedtensor_v = is_batchedtensor(t)
- is_legacy_batchedtensor_v = is_legacy_batchedtensor(t)
- is_gradtrackingtensor_v = is_gradtrackingtensor(t)
- is_functorch_batched_or_grad = is_batchedtensor_v or is_gradtrackingtensor_v
- is_functional = torch._is_functional_tensor(t)
- storage = None
- # NB: For compatibility, I default this to zero, as sometimes people
- # still have stuffed zero into storage offset even though the tensor
- # doesn't meaningfully have an offset
- storage_offset = 0
- if not (
- is_sparse
- or is_sparse_compressed_layout(layout)
- or (is_nested and not is_traceable_wrapper_subclass_v)
- or is_mkldnn
- # TODO: TBH, functorch wrapped tensors probably should have
- # storage associated with them
- or is_functorch_wrapped
- or is_legacy_batchedtensor_v
- ):
- # NB: We actually don't use storage to do views, but might as well
- # put it in for accuracy
- storage = self.describe_storage(t.untyped_storage(), trace=trace)
- storage_offset = t.storage_offset()
- stride = None
- if not (
- is_sparse
- or is_sparse_compressed_layout(layout)
- or (is_nested and not is_traceable_wrapper_subclass_v)
- ):
- # stride/storage_offset are called from is_functorch_wrapped,
- # view_from_base, empty_create_subclass,
- # sym_sizes_strides_storage_offset (empty_create)
- stride = t.stride()
- # NB: this technically should refer to functorch unwrapped tensor, but
- # I am (perhaps abusively) using it to store both the functorch and
- # non-functorch functional tensor
- unwrapped = None
- autograd_meta_from = None
- current_level = None
- if is_batchedtensor_v or is_gradtrackingtensor_v:
- unwrapped = self.describe_tensor(get_unwrapped(t), trace=trace)
- # xla and lazy tensors present as functional tensors, but we want them
- # to be handled specially
- elif is_functional and t.device.type not in ("xla", "lazy"):
- if t._is_view():
- raise RuntimeError(
- "Cannot safely fakify a view because this process drops the view information right now."
- )
- if not is_functorch_wrapped:
- torch._sync(t)
- unwrapped = self.describe_tensor(
- torch._from_functional_tensor(t), trace=trace
- )
- autograd_meta_from = t
- else:
- reapply_views = torch._C._functionalization_reapply_views_tls()
- # NB: has side effects!
- unwrapped = self.describe_tensor(
- _unwrap_functional_tensor(t, reapply_views), trace=trace
- )
- # TODO: It's pretty suspicious that functional tensors don't have
- # a valid level and thus we just grab whatever the current level
- # is
- current_level = torch._C._functorch.current_level()
- maybe_functorch_stack = None
- if is_functorch_wrapped:
- with torch._functorch.pyfunctorch.temporarily_clear_interpreter_stack() as maybe_functorch_stack:
- pass
- attrs = None
- ctx = None
- type_v = None
- if is_traceable_wrapper_subclass_v:
- assert hasattr(t, "__tensor_flatten__")
- raw_attrs, ctx = t.__tensor_flatten__()
- attrs = {
- attr: self.describe_tensor(getattr(t, attr), trace=trace)
- for attr in raw_attrs
- }
- type_v = type(t)
- # TODO: Is it important to enable torch.inference_mode before querying
- # these values?
- r = MetaTensorDesc(
- id=self.get_tensor_id(t),
- storage=storage,
- is_inference=t.is_inference(),
- is_leaf=is_leaf,
- requires_grad=t.requires_grad,
- # NB: ndim should be OK too but there is a disaster at
- # python test/dynamo/test_subclasses.py -k test_user_overidden_property_unsupported
- # Actually, this means that we have a little bit of a problem
- # here, which is that there is some sensitivity to how exactly an
- # access is done if you have a __torch_function__ subclass. Maybe we
- # should disable torch function before doing accesses?
- ndim=t.dim(),
- dtype=t.dtype,
- is_sparse=is_sparse,
- is_mkldnn=is_mkldnn,
- is_functorch_wrapped=is_functorch_wrapped,
- is_batchedtensor=is_batchedtensor_v,
- is_legacy_batchedtensor=is_legacy_batchedtensor_v,
- is_gradtrackingtensor=is_gradtrackingtensor_v,
- is_view=is_view,
- is_conj=t.is_conj(),
- is_neg=t.is_neg(),
- is_parameter=isinstance(t, torch.nn.Parameter),
- is_traceable_wrapper_subclass=is_traceable_wrapper_subclass_v,
- is_nested=is_nested,
- is_functional=is_functional,
- layout=layout,
- device=t.device,
- size=t.size(),
- stride=stride,
- storage_offset=storage_offset,
- dynamo_dynamic_indices=list(getattr(t, "_dynamo_dynamic_indices", set())),
- sparse_dim=t.sparse_dim()
- if t.is_sparse or is_sparse_compressed(t)
- else None,
- dense_dim=t.dense_dim() if t.is_sparse or is_sparse_compressed(t) else None,
- is_coalesced=t.is_coalesced() if t.is_sparse else None,
- # TODO: I actually think recursing here is correct, but we have at
- # least an infinite cycle from base -> values -> base
- # https://github.com/pytorch/pytorch/issues/122089
- crow_indices=self.describe_tensor(
- t.crow_indices(), recurse=False, trace=trace
- )
- if recurse and t.layout in {torch.sparse_csr, torch.sparse_bsr}
- else None,
- col_indices=self.describe_tensor(
- t.col_indices(), recurse=False, trace=trace
- )
- if recurse and t.layout in {torch.sparse_csr, torch.sparse_bsr}
- else None,
- ccol_indices=self.describe_tensor(
- t.ccol_indices(), recurse=False, trace=trace
- )
- if recurse and t.layout in {torch.sparse_csc, torch.sparse_bsc}
- else None,
- row_indices=self.describe_tensor(
- t.row_indices(), recurse=False, trace=trace
- )
- if recurse and t.layout in {torch.sparse_csc, torch.sparse_bsc}
- else None,
- values=self.describe_tensor(t.values(), recurse=False, trace=trace)
- if recurse and is_sparse_compressed(t)
- else None,
- grad=self.describe_tensor(safe_grad(t), trace=trace)
- if safe_grad(t) is not None
- else None,
- creation_meta=torch._C._autograd._get_creation_meta(t)
- if t._is_view()
- else None,
- unwrapped=unwrapped,
- level=maybe_get_level(t)
- if is_batchedtensor_v or is_gradtrackingtensor_v
- else None,
- bdim=maybe_get_bdim(t) if is_batchedtensor_v else None,
- base=self.describe_tensor(t._base, trace=trace)
- if recurse and t._is_view() and t._base is not None
- else None,
- fake_mode=torch._subclasses.fake_tensor.maybe_get_fake_mode(t),
- view_func=t._view_func_unsafe,
- attrs=attrs,
- ctx=ctx,
- type=type_v,
- # NB: even if functorch is enabled, don't actually save the
- # interpreter stack here unless we are actually functorch wrapped;
- # it's irrelevant for non-functorch stuff
- functorch_stack=maybe_functorch_stack,
- autograd_meta_from=autograd_meta_from,
- current_level=current_level,
- data=t if self.copy_data else None,
- )
- if trace and r.id not in self.traced_tensors:
- trace_structured(
- "describe_tensor",
- metadata_fn=lambda: r.as_json(self.id),
- )
- self.traced_tensors.add(r.id)
- return r
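- # Illustrative sketch of the statefulness described in the class docstring
- # (hypothetical usage, not part of the original file): describing the same
- # tensor twice through one describer yields the same id, and a view records
- # its base, which is how aliasing relationships are preserved downstream.
- #
- #   describer = MetaTensorDescriber()
- #   x = torch.randn(4)
- #   assert describer.describe_tensor(x).id == describer.describe_tensor(x).id
- #   v = x[1:3]
- #   dv = describer.describe_tensor(v)
- #   assert dv.is_view and dv.base is not None and dv.storage is not None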
- @dataclass(frozen=True)
- class MetaStorageDesc:
- id: MetaStorageId
- size: int
- # NB: this is only populated when copy_data is True; it is not directly
- # serializable in JSON, so you will want to do something special here anyway
- data: Optional[torch.UntypedStorage]
- def as_json(self, describer_id):
- return {
- "id": self.id,
- "describer_id": describer_id,
- "size": self.size if isinstance(self.size, int) else repr(self.size),
- }
- @dataclass(frozen=True)
- class MetaTensorDesc:
- id: MetaTensorId
- ndim: int
- dtype: torch.dtype
- device: torch.device
- # NB: Sometimes, size, stride and storage_offset contain SymInt, in which
- # case this is NOT serializable. That only happens when you're
- # re-fakeifying a fake tensor with an existing ShapeEnv... maybe we
- # can get rid of this use case entirely. Notably, even if we are
- # fakeifying a real tensor into a fake tensor with symbolic shapes, the
- # size here is NOT dynamic
- # NB: These also contain SymInt because wrap_meta_outputs_with_default_device_logic
- # goes through this codepath. But it really should not LOL.
- # NB: size could potentially be None, as you can override it and make it
- # throw an error, but we don't currently have any subclasses that do this
- # except C++ nested tensor; but we're going to have nested int to make this
- # defined on NJT
- size: Tuple[int, ...]
- dynamo_dynamic_indices: List[int]
- layout: torch.layout = torch.strided
- is_inference: bool = False
- is_leaf: bool = False
- requires_grad: bool = False
- is_sparse: bool = False
- is_mkldnn: bool = False
- is_functorch_wrapped: bool = False
- is_batchedtensor: bool = False
- is_legacy_batchedtensor: bool = False
- is_gradtrackingtensor: bool = False
- is_view: bool = False
- is_nested: bool = False
- is_traceable_wrapper_subclass: bool = False
- is_functional: bool = False
- is_conj: bool = False
- is_neg: bool = False
- is_parameter: bool = False
- stride: Optional[Tuple[int, ...]] = None
- storage_offset: int = 0
- # NB: We have a choice of whether to store the id or a direct pointer
- # to the data structure. For ease of use, we store the data structure,
- # but this means that when we serialize, we have to swizzle these pointers
- # back into ids (so we have accurate aliasing relationships)
- storage: Optional[MetaStorageDesc] = None
- sparse_dim: Optional[int] = None # is_sparse, is_sparse_compressed
- dense_dim: Optional[int] = None # is_sparse, is_sparse_compressed
- is_coalesced: Optional[bool] = None # is_sparse
- crow_indices: Optional[MetaTensorDesc] = None # is_sparse_compressed
- col_indices: Optional[MetaTensorDesc] = None # is_sparse_compressed
- ccol_indices: Optional[MetaTensorDesc] = None # is_sparse_compressed
- row_indices: Optional[MetaTensorDesc] = None # is_sparse_compressed
- values: Optional[MetaTensorDesc] = None # is_sparse_compressed
- unwrapped: Optional[MetaTensorDesc] = None # is_functorch_wrapped
- bdim: Optional[int] = None # is_functorch_wrapped
- base: Optional[MetaTensorDesc] = None # is_view
- attrs: Optional[Dict[str, MetaTensorDesc]] = None # is_traceable_wrapper_subclass
- creation_meta: Optional[CreationMeta] = None
- grad: Optional[MetaTensorDesc] = None
- # Everything below is NOT serializable, need some more work
- _UNSERIALIZABLE: ClassVar[List[str]] = [
- "ctx",
- "type",
- "fake_mode",
- "view_func",
- "level",
- "current_level",
- "functorch_stack",
- "autograd_meta_from",
- "data",
- ]
- ctx: Optional[object] = None # is_traceable_wrapper_subclass
- type: Optional[Type] = None # is_traceable_wrapper_subclass
- fake_mode: Optional[FakeTensorMode] = None
- view_func: Optional[
- Callable[
- [
- torch.Tensor,
- Callable[[int], int],
- Callable[[torch.Tensor], torch.Tensor],
- ],
- torch.Tensor,
- ]
- ] = None
- # level looks serializable, but actually it is meaningless without
- # the functorch_stack below
- level: Optional[int] = None # is_functorch_wrapped
- current_level: Optional[int] = None
- functorch_stack: Optional[List[CInterpreter]] = None
- autograd_meta_from: Optional[torch.Tensor] = None
- # This is only populated on copy_data, and typically is not used at all,
- # except for some of our meta-ification paths that don't properly use
- # storage (pro-tip: you should use storage)
- data: Optional[torch.Tensor] = None
- # Faithfully serializing functorch tensors will not be too difficult.
- # We only need to consider grad/vmap interpreters, and their internal
- # state is only bools (mostly what the grad enabled/disabled state
- # should be in the lower layer). Beyond that, tensors just need to
- # precisely indicate which particular interpreter they correspond
- # to (we then replace level with a pointer to the interpreter stack.)
- # However, this use of functorch is very "non-lexical" so it's not
- # entirely clear how to make it all lexical again, so we haven't done
- # it for now.
- # NB: This will reference numeric IDs, and it is assumed that you've
- # already serialized everything this recursively references
- def as_json(self, describer_id):
- def json(k, v):
- # Some best-effort debugging serialization for unserializable
- # fields (feel free to add other special cases as appropriate)
- if k in ["data", "autograd_meta_from"]:
- return None # never repr these
- if k in set(MetaTensorDesc._UNSERIALIZABLE):
- return repr(v)
- if isinstance(v, (torch.device, torch.dtype, torch.layout)):
- return repr(v)
- if isinstance(v, torch.SymInt):
- return repr(v)
- if isinstance(v, (tuple, list)):
- return [json(k, v1) for v1 in v]
- if isinstance(v, (MetaStorageDesc, MetaTensorDesc)):
- return v.id
- if isinstance(v, CreationMeta):
- return str(v)
- if k == "attrs" and isinstance(v, dict):
- return {k1: v1.id for k1, v1 in v.items()}
- return v
- r = {
- field.name: json(field.name, getattr(self, field.name))
- for field in dataclasses.fields(self)
- if not (
- getattr(self, field.name) is field.default
- or (
- field.name == "dynamo_dynamic_indices"
- and not getattr(self, field.name)
- )
- )
- }
- r.update({"describer_id": describer_id})
- return r
- @property
- def shape(self):
- return self.size
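- # Illustrative sketch (hypothetical usage, not part of the original file):
- # as_json collapses nested descriptors to their numeric ids and reprs the
- # unserializable fields, so the output is plain JSON-able data.
- #
- #   describer = MetaTensorDescriber()
- #   d = describer.describe_tensor(torch.randn(2, 2))
- #   payload = d.as_json(describer.id)
- #   assert payload["id"] == d.id and payload["describer_id"] == describer.id
- #   assert payload["storage"] == d.storage.id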
- # A more faithful reproduction would do a copy on the entire
- # storage, but this needs to be done carefully because the
- # underlying storage could have larger extent than is implied
- # by size/stride. The real fix is to properly call
- # meta_storage recursively here.
- #
- # These "safe" functions are intended to be used under no_dispatch() mode.
- # The no_dispatch() here is intended to prevent ambient fake tensor mode from
- # fakeifying the operation. But if we are given an honest to goodness
- # FakeTensor as src, we MUST NOT run the copy/clone operation. A better way
- # to do this would be to not use no_dispatch and instead just disable fake
- # tensor mode only (allowing for subclass dispatch to occur)
- def _safe_copy(dst, src):
- if type(src) is not torch.Tensor:
- return
- dst.copy_(src)
- def _safe_clone(src):
- if type(src) is not torch.Tensor:
- return None
- return src.clone()
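- # For example, a FakeTensor is a Tensor subclass, so type(src) is not
- # torch.Tensor and both helpers bail out without dispatching any ops on it;
- # a plain CPU tensor is copied/cloned as usual.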
- # This is a class for converting multiple tensors into meta tensors which
- # share the same view/storage structure. The operational model is that you
- # allocate one of these, and then call it repeatedly on all the tensors you
- # want to convert. It's important to use the same object for tensors you want
- # to share storage, because this is how we correlate shared storages to the
- # same meta storages. This class will hold weak references to cached tensors
- # and tensor storages.
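- #
- # Illustrative sketch (hypothetical usage, not part of the original file):
- #
- #   converter = MetaConverter()
- #   base = torch.randn(4)
- #   view = base[1:3]
- #   meta_base = converter(base)   # a device='meta' tensor mirroring base
- #   meta_view = converter(view)   # reuses the memoized base/storage, so the
- #                                 # view relationship is reconstructed
- #   assert meta_base.is_meta and meta_view._is_view()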
- class MetaConverter:
- def __init__(self, *, copy_data: bool = False):
- # Maps MetaStorageId to UntypedStorage
- self.storage_memo: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
- # Maps MetaTensorId to torch.Tensor (typically a meta tensor or
- # FakeTensor)
- self.tensor_memo: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
- self.hit = 0
- self.miss = 0
- self.del_hook = None
- self.arg_cnt = 0
- # Ensures real_storage/real_tensor are populated on the resulting
- # metaified storage/tensor. The naming of this attribute is load
- # bearing: FakeTensor relies on real tensor being set to exactly this
- # value
- self.copy_data = copy_data
- self.describer = MetaTensorDescriber(copy_data=copy_data)
- def successful(self):
- return self.hit > 0 and self.miss == 0
- def get_tensor_memo(self, t: MetaTensorDesc):
- return self.tensor_memo.get(t.id, None)
- def set_tensor_memo(self, t: MetaTensorDesc, v):
- self.tensor_memo[t.id] = v
- def get_storage_memo(self, s: MetaStorageDesc):
- return self.storage_memo.get(s.id, None)
- def set_storage_memo(self, s: MetaStorageDesc, v):
- self.storage_memo[s.id] = v
- def meta_storage(self, s: MetaStorageDesc, callback):
- # If we are fakeifying a tensor that has a secretly-zero-sized storage,
- # we need to make sure to resize the meta storage too.
- if self.get_storage_memo(s) is None:
- r_s = callback(
- lambda: torch.empty(s.size, dtype=torch.uint8, device="meta"),
- ).untyped_storage()
- if self.copy_data:
- # NB: no_dispatch is needed because internally storage copy is
- # implemented as Tensor operations
- with torch.no_grad(), no_dispatch():
- assert s.data is not None
- r_s.real_storage = s.data.clone()
- self.set_storage_memo(s, r_s)
- return r_s
- else:
- return self.get_storage_memo(s)
- # This function assumes that it's possible to do the conversion
- # NB: name here is used in a conventional way by Dynamo; it corresponds
- # precisely to the Source.name() of the tensor we're fakeifying and
- # corresponds to a valid Python expression. When we construct sub-names
- # as part of this process, we will maintain this invariant! (Even though
- # other users of this may not need this property to be upheld.)
- def meta_tensor(
- self,
- t: MetaTensorDesc,
- shape_env: Optional[ShapeEnv] = None,
- callback=lambda t: t(),
- source: Optional[Source] = None,
- symbolic_context: Optional[SymbolicContext] = None,
- ):
- if source is None:
- from torch._dynamo.source import ConstantSource
- # TODO: make a dedicated UnknownSource for this?
- source = ConstantSource(
- f"__meta_utils_unknown_tensor{len(self.tensor_memo)}"
- )
- # This indicates you set no_dispatch() before calling into this
- # function. This is an error: we may be creating fake tensors and
- # will perform operations on them which need fake tensor mode to
- # be active. You will segfault if you are in a no_dispatch() block.
- assert not torch._C._dispatch_tls_local_exclude_set().has(
- torch._C.DispatchKey.Python
- )
- arg_cnt = self.arg_cnt
- self.arg_cnt += 1
- # When we make as_strided calls, we end up generating a guard
- # that the new as_strided tensor is in bounds for the old storage
- # for the base (since as_strided calls can "bust" out of their
- # bounding box.) This guard is unnecessary: if a user is able
- # to provide us a tensor with the view base setup this way, we
- # don't need to produce a guard, because the fact that they
- # were able to produce the view base means it's in bounds.
- #
- # Now, ordinarily, this guard would be harmless. However, the
- # generated guard refers to variables bound on the base variable.
- # At the moment, Dynamo doesn't actually guard on x._base, because
- # according to Voz this results in a lot of spurious invalidations,
- # and also if the user doesn't directly make use of _base, its
- # pointless anyway (because programs should be parametric over
- # whether or not the input tensor is a view or not--unless you're
- # mutating the input, but that's a whole 'nother ballgame). So
- # for expediency, we suppress these guards so we don't have to
- # deal with this (yet, anyway.)
- #
- # NB: An old version of this code suppressed guards for ALL operations
- # happening during meta conversion, not just as_strided calls.
- # This is too aggressive: we do duck sizing and 0/1 simplification
- # as we allocate variables, and we do need to register guards for
- # these cases.
- maybe_suppress: Callable[[], Any] = contextlib.nullcontext
- if shape_env is not None:
- maybe_suppress = shape_env.suppress_guards
- def sym_sizes_strides_storage_offset(
- t: MetaTensorDesc, src, symbolic_context=symbolic_context
- ) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]:
- assert t.stride is not None
- if shape_env is not None:
- fake_mode = t.fake_mode
- if fake_mode is not None and fake_mode.shape_env is shape_env:
- # Don't reallocate the sizes; the shape envs are the same,
- # so reuse the old sizes/strides/etc
- return (t.size, t.stride, t.storage_offset)
- else:
- # TODO: deduplicate this
- t_size = tuple(
- shape_env._maybe_specialize_sym_int_with_hint(sz)
- for sz in t.size
- )
- t_stride = tuple(
- shape_env._maybe_specialize_sym_int_with_hint(sd)
- for sd in t.stride
- )
- t_storage_offset = shape_env._maybe_specialize_sym_int_with_hint(
- t.storage_offset
- )
- return shape_env._create_symbolic_sizes_strides_storage_offset(
- t_size,
- t_stride,
- t_storage_offset,
- [d in t.dynamo_dynamic_indices for d in range(t.ndim)],
- src,
- symbolic_context=symbolic_context,
- )
- else:
- return (t.size, t.stride, t.storage_offset)
- def empty_create(
- inner_t: MetaTensorDesc, inner_src, symbolic_context=symbolic_context
- ):
- (
- inner_sizes,
- inner_strides,
- inner_storage_offset,
- ) = sym_sizes_strides_storage_offset(inner_t, inner_src, symbolic_context)
- return torch.empty_strided(
- inner_sizes,
- inner_strides,
- dtype=inner_t.dtype,
- device="meta",
- )
- # Creates a subclass instance with empty inner tensors according to the specified
- # symbolic context.
- def empty_create_subclass(
- t: MetaTensorDesc,
- outer_size,
- outer_stride,
- symbolic_context=symbolic_context,
- callback=callback,
- source=source,
- ):
- from torch._dynamo.source import AttrSource
- from torch.fx.experimental.symbolic_shapes import SubclassSymbolicContext
- assert t.attrs is not None
- assert t.type is not None
- # NB: t.ctx could be None if the subclass in question has no
- # meaningful context
- assert symbolic_context is None or isinstance(
- symbolic_context, SubclassSymbolicContext
- )
- # Note: transform_subclass will use __tensor_unflatten__ to generate
- # a fresh subclass wrapper with outer sizes / strides according to the
- # outer symbolic context (passed in to this function). Inner size / stride
- # / storage offset symbols are allocated according to the appropriate inner
- # symbolic contexts, after which the checks in transform_subclass() will
- # relate them to the outer metadata where possible.
- #
- # Morally, the code here is the same as transform_subclass, but we've
- # written it from scratch to read EmptyCreateSubclass
- outer_size = outer_size if outer_size is not None else t.size
- outer_stride = outer_stride if outer_stride is not None else t.stride
- def transform(attr, inner_t):
- r = callback(
- lambda: empty_create(
- inner_t,
- AttrSource(source, attr),
- symbolic_context=(
- None
- if symbolic_context is None
- else symbolic_context.inner_contexts[attr]
- ),
- )
- )
- if self.copy_data:
- with torch.no_grad(), no_dispatch():
- r.real_tensor = torch.empty_strided(
- inner_t.size,
- inner_t.stride,
- dtype=inner_t.dtype,
- device=inner_t.device,
- )
- assert inner_t.data is not None
- _safe_copy(r.real_tensor, inner_t.data)
- return r
- transformed_tensors_dict = {
- attr: transform(attr, inner_t) for attr, inner_t in t.attrs.items()
- }
- sub = t.type.__tensor_unflatten__(
- transformed_tensors_dict, t.ctx, outer_size, outer_stride
- )
- # NB: Purposefully guard here to simplify the inner / outer symbols.
- # Using sym_eq() for symbolic comparison can result in an expression that's too
- # difficult to guard on, so we use == here.
- assert sub.shape == outer_size, (
- f"Expected return value from {t.type}__tensor_unflatten__() to have "
- f"shape equal to {outer_size}, but got: {sub.shape}"
- )
- assert sub.stride() == outer_stride, (
- f"Expected return value from {t.type}__tensor_unflatten__() to have "
- f"stride equal to {outer_stride}, but got: {sub.stride()}"
- )
- return sub
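- # Illustrative sketch of the subclass contract relied on above (hypothetical,
- # simplified subclass; not part of this file; a real traceable wrapper
- # subclass would also implement __torch_dispatch__):
- #
- #   class TwoTensor(torch.Tensor):
- #       @staticmethod
- #       def __new__(cls, a, b):
- #           return torch.Tensor._make_wrapper_subclass(cls, a.shape, dtype=a.dtype)
- #       def __init__(self, a, b):
- #           self.a, self.b = a, b
- #       def __tensor_flatten__(self):
- #           # names of the inner tensor attrs + an opaque ctx (may be None)
- #           return ["a", "b"], None
- #       @staticmethod
- #       def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
- #           # rebuilds the wrapper from fresh inner tensors; outer_size/stride
- #           # describe the wrapper's own metadata
- #           return TwoTensor(inner_tensors["a"], inner_tensors["b"])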
- # Returns an all-dynamic symbolic context used for metafying the given tensor with
- # fully dynamic dims. This is useful when fake-ifying intermediate tensors in
- # closed-over ViewFunc state, as we don't have symbolic contexts for them, but we
- # don't want to over-specialize during view replay.
- def all_dynamic_symbolic_context(
- t: MetaTensorDesc, source, shape_env, callback
- ):
- from torch._dynamo.source import AttrSource
- from torch.fx.experimental.symbolic_shapes import (
- DimDynamic,
- StatelessSymbolicContext,
- SubclassSymbolicContext,
- )
- view_base_context: Optional[SymbolicContext] = None
- if t.is_view:
- assert t.base is not None
- view_base_context = all_dynamic_symbolic_context(
- t.base, AttrSource(source, "_base"), shape_env, callback
- )
- t_symbolic_context: SymbolicContext
- t_dynamic_sizes = [DimDynamic.DYNAMIC] * t.ndim
- if t.is_traceable_wrapper_subclass:
- assert t.attrs is not None
- inner_contexts: Dict[str, SymbolicContext] = {}
- for attr, inner in t.attrs.items():
- assert isinstance(attr, str)
- inner_contexts[attr] = all_dynamic_symbolic_context(
- inner, AttrSource(source, attr), shape_env, callback
- )
- t_symbolic_context = SubclassSymbolicContext(
- dynamic_sizes=t_dynamic_sizes,
- constraint_sizes=[None] * t.ndim,
- inner_contexts=inner_contexts,
- tensor_source=source,
- view_base_context=view_base_context,
- )
- else:
- t_symbolic_context = StatelessSymbolicContext(
- dynamic_sizes=t_dynamic_sizes,
- constraint_sizes=[None] * t.ndim,
- view_base_context=view_base_context,
- )
- return t_symbolic_context
- # Returns a fake-ified version of an input view tensor t, given an already fake-ified
- # base. At a high level, we want two things:
- # 1. fake_t should have the same view relationship to the given fake base as the
- # input t has to its _base.
- # 2. fake_t should have symbolic sizes / strides / storage offset according to the
- # appropriate symbolic context (i.e. from the automatic dynamic algorithm).
- #
- # We currently take different strategies across view types:
- # * For dense -> dense views, accomplish both (1) and (2) simultaneously via an
- # as_strided() call on the fake-ified base, passing symbolic metadata.
- # * For views involving subclasses, perform view replay using view funcs to
- # achieve (1). It's necessary for (2) to swap out any closed-over state in
- # the view funcs with symbolicized SymInts and fake-ified tensors. Doing this
- # avoids specialization (and thus over-eager simplification of symbols) that
- # could occur during view replay on the fake-ified base.
- #
- # Examples:
- # * t.unsqueeze(-1) with dense t is a dense -> dense view. It can be modeled
- # with an as_strided() call on the fake base passing symbolic metadata.
- # * sub.select(dim=0, index=3) is a subclass -> subclass view. The index arg
- # is made symbolic to avoid invalid specialization and view replay is then
- # done to reconstruct the view.
- # * _nested_from_jagged(values, offsets) is a dense -> subclass view
- # that returns a subclass instance from a dense values tensor. The offsets
- # tensor is closed over in the view func, as it can be considered view metadata.
- # First, the offsets tensor is fake-ified according to the inner symbolic
- # context and with the correct relationship to the outer size / stride metadata.
- # Then view replay is done, swapping in the fake offsets so the view replay output
- # is fully fake with no invalid specialization.
- def view_from_base(
- base: torch.Tensor, t: MetaTensorDesc, source=source, shape_env=shape_env
- ):
- # fake-ify t's metadata according to the outer symbolic context
- (sizes, strides, storage_offset) = sym_sizes_strides_storage_offset(
- t, source
- )
- if (
- not t.is_traceable_wrapper_subclass
- and not is_traceable_wrapper_subclass(base)
- ):
- # Dense -> Dense view case uses as_strided() to construct view relationship.
- # TODO: Change this logic to use view replay for consistency?
- # It's likely there is no view func available.
- with maybe_suppress():
- return base.as_strided(sizes, strides, storage_offset)
- from torch._dynamo.source import EphemeralSource
- from torch.fx.experimental.symbolic_shapes import (
- StatelessSymbolicContext,
- sym_eq,
- )
- def symint_visitor_fn(s):
- nonlocal symbolic_context
- from torch.fx.experimental.symbolic_shapes import DimDynamic
- all_static_sizes = (
- symbolic_context is not None
- and isinstance(symbolic_context, StatelessSymbolicContext)
- and all(
- x is DimDynamic.STATIC for x in symbolic_context.dynamic_sizes
- )
- )
- # Can't just rely on shape env being None - dynamo always initializes it
- if all_static_sizes or shape_env is None:
- return s
- # NB: The symbol here is expected to be simplified out because we a priori
- # allocate inner and outer symbols according to the appropriate symbolic
- # contexts and prefer those over this symbol during symbol simplification
- # (via usage of EphemeralSource below). This -shouldn't- happen, but if
- # this symbol somehow leaks out beyond the view tensor's shape metadata, our
- # assumption of it being simplified out will fail and it may be guarded on,
- # which will hard error.
- sym_source = EphemeralSource("symint_visitor_fn")
- symbol = shape_env.create_symbol(s, sym_source)
- return shape_env.create_symintnode(symbol, hint=s, source=sym_source)
- real_to_fake_mapping = {}
- if t.is_traceable_wrapper_subclass:
- assert t.attrs is not None
- # NB: t.ctx could be None if the subclass in question has no
- # meaningful context
- assert t.type is not None
- # Fake-ify t naively here; this is only done so we can get fake-ified inner
- # tensors with the correct relationships to the outer sizes / strides for use
- # in view replay. It's done beforehand here because it's not easy to do when
- # visiting tensors one-by-one during view replay.
- #
- # Example:
- # Consider a Dense -> NJT view. NJT has (values, offsets) components and we
- # want a view of values with the offsets closed over. As the offsets component
- # is needed to describe the output view, it's important that it's fakeified
- # correctly.
- fake_t = empty_create_subclass(
- t, outer_size=sizes, outer_stride=strides
- )
- attrs, _ = fake_t.__tensor_flatten__()
- for attr in attrs:
- real_to_fake_mapping[t.attrs[attr].id] = getattr(fake_t, attr)
- def tensor_visitor_fn(
- visited_t: torch.Tensor,
- # These arguments are never passed; we just use them to close
- # over these relevant values
- shape_env=shape_env,
- callback=callback,
- ):
- # It's possible to close over an undefined tensor (e.g. NJT's lengths).
- if visited_t is None:
- return None
- # NB: visited_t being a Tensor here is very naughty! Should
- # have already been described
- # Fake inner tensors of view subclasses will come from the mapping built above.
- visited_id = self.describer.get_tensor_id(visited_t)
- fake_visited_t = real_to_fake_mapping.get(visited_id, None)
- if fake_visited_t is not None:
- return fake_visited_t
- visited_desc = self.describer.describe_tensor(visited_t)
- # For other closed-over tensor state, fake-ify it as all dynamic with an
- # ephemeral source. This avoids invalid specialization during view replay.
- # If we find that in practice the usage of ephemeral sources isn't enough
- # to guarantee that we don't have guards on these symbols, we may need to
- # explicitly suppress guards (as is done for _base in the dense -> dense
- # view case).
- temp_source = EphemeralSource("tensor_visitor_fn")
- return self.meta_tensor(
- visited_desc,
- shape_env,
- callback,
- source=temp_source,
- symbolic_context=all_dynamic_symbolic_context(
- visited_desc, temp_source, shape_env, callback
- ),
- )
- # Replay the view, swapping out any non-symbolic SymInts or real tensors
- # for symbolic SymInts or fake tensors.
- assert t.view_func is not None
- # NB: we do NOT suppress guards here, we need to remove ephemeral
- # sources
- fake_t = t.view_func(base, symint_visitor_fn, tensor_visitor_fn)
- # Ensure the output has symbolic shapes according to the outer symbolic context.
- # These checks should simplify out any symbols created for closed-over view func
- # SymInts.
- torch._check(sym_eq(fake_t.size(), sizes))
- torch._check(sym_eq(fake_t.stride(), strides))
- torch._check(sym_eq(fake_t.storage_offset(), storage_offset))
- return fake_t
- if self.get_tensor_memo(t) is None:
- GRAD_TENSOR_SENTINEL_VALUE = -2
- with torch.inference_mode(t.is_inference):
- if t.is_sparse:
- is_leaf = t.is_leaf
- # The lambda function below is similar to
- # `t.to(device='meta')` except the latter
- # preserves nnz value
- r = callback(
- lambda: torch.ops.aten._sparse_coo_tensor_with_dims(
- t.sparse_dim,
- t.dense_dim,
- t.size,
- dtype=t.dtype,
- layout=torch.sparse_coo,
- device="meta",
- )
- )
- if self.copy_data:
- # Pray that sparse clone doesn't lose information
- assert t.data is not None
- with torch.no_grad(), no_dispatch():
- r.real_tensor = _safe_clone(t.data)
- assert safe_is_leaf(r), "the callback you passed in doesn't detach"
- # Note [is_coalesced is dispatched]
- # Strangely enough, is_coalesced() is a dispatched operator,
- # which means that it will get caught by fake tensor mode.
- # Ordinarily this would error, but there's some logic in
- # fake tensor to ensure this doesn't happen.
- r._coalesced_(t.is_coalesced)
- if t.requires_grad:
- r.requires_grad = True
- if t.requires_grad and not is_leaf:
- # This should probably use DelayedError,
- # but clone is fine for now for sparse tensors.
- # (DelayedError does not work for sparse because it causes
- # the Fake sparse tensor to "lose" its fakeness)
- r = r.clone()
- with torch.enable_grad():
- r._coalesced_(t.is_coalesced)
- elif is_sparse_compressed_layout(t.layout):
- is_leaf = t.is_leaf
- if t.layout in {torch.sparse_bsr, torch.sparse_bsc}:
- assert t.sparse_dim is not None
- assert t.dense_dim is not None
- assert t.values is not None
- batch_dim = t.ndim - t.sparse_dim - t.dense_dim
- blocksize = t.values.shape[batch_dim + 1 : batch_dim + 3]
- else:
- blocksize = ()
- if t.layout in {torch.sparse_csr, torch.sparse_bsr}:
- assert t.crow_indices is not None
- index_dtype = t.crow_indices.dtype
- else:
- assert t.ccol_indices is not None
- index_dtype = t.ccol_indices.dtype
- r = callback(
- lambda: torch.ops.aten._sparse_compressed_tensor_with_dims(
- 0,
- t.dense_dim,
- t.shape,
- blocksize,
- index_dtype,
- layout=t.layout,
- dtype=t.dtype,
- device="meta",
- )
- )
- if self.copy_data:
- # Pray sparse clone doesn't lose information
- assert t.data is not None
- with torch.no_grad(), no_dispatch():
- r.real_tensor = _safe_clone(t.data)
- assert safe_is_leaf(r), "the callback you passed in doesn't detach"
- if t.requires_grad:
- r.requires_grad = True
- if t.requires_grad and not is_leaf:
- r = torch._C._functions.DelayedError(
- "Internal error: Tried to backward() through example input",
- 1,
- )(r)
- elif t.is_nested and not t.is_traceable_wrapper_subclass:
- # TODO: Handle this better in Dynamo?
- # There are checks there now, but this can still be triggered by a dense
- # tensor graph input that is a view of a strided NT.
- from torch._dynamo.exc import unimplemented
- unimplemented(
- "strided nested tensors are not supported by meta conversion"
- )
- elif t.is_mkldnn:
- is_leaf = t.is_leaf
- sizes, strides, _storage_offset = sym_sizes_strides_storage_offset(
- t, source
- )
- # TODO: This doesn't seem right, where's the MKLDNN'ness
- # lol
- r = callback(
- lambda: torch.empty_strided(
- sizes, strides, dtype=t.dtype, device="meta"
- )
- )
- if self.copy_data:
- with torch.no_grad(), no_dispatch():
- assert t.size is not None
- assert t.stride is not None
- r.real_tensor = torch.empty_strided(
- t.size, t.stride, dtype=t.dtype, device=t.device
- )
- assert t.data is not None
- _safe_copy(r.real_tensor, t.data)
- assert safe_is_leaf(r), "the callback you passed in doesn't detach"
- if t.requires_grad:
- r.requires_grad = True
- if t.requires_grad and not is_leaf:
- r = torch._C._functions.DelayedError(
- "Internal error: Tried to backward() through example input",
- 1,
- )(r)
- elif t.is_functorch_wrapped:
- if t.is_view:
- from torch._dynamo.exc import unimplemented
- unimplemented(
- "view functorch tensors are not supported by meta conversion"
- )
- # Wraps a functorch tensor class (BatchedTensor, GradTrackingTensor)
- # in a FakeTensor
- def _to_fake_tensor(t: MetaTensorDesc):
- # TODO: why aren't the recursive calls going to
- # meta_tensor
- if t.is_batchedtensor:
- assert t.unwrapped is not None
- assert t.level is not None
- assert t.bdim is not None
- ft = _to_fake_tensor(t.unwrapped)
- lvl = t.level
- bdim = t.bdim
- # You cannot create functorch tensors without
- # having the ambient functorch interpreter stack
- # available, as the level refers to things in the
- # stack
- with torch._functorch.pyfunctorch.temporarily_restore_interpreter_stack(
- t.functorch_stack
- ):
- r = _add_batch_dim(ft, bdim, lvl)
- elif t.is_gradtrackingtensor:
- assert t.unwrapped is not None
- assert t.level is not None
- disable_functorch = torch._C._DisableFuncTorch
- with disable_functorch():
- ft = _to_fake_tensor(t.unwrapped)
- lvl = t.level
- if lvl == GRAD_TENSOR_SENTINEL_VALUE:
- r = ft
- else:
- with torch._functorch.pyfunctorch.temporarily_restore_interpreter_stack(
- t.functorch_stack
- ):
- r = torch._C._functorch._wrap_for_grad(ft, lvl)
- is_leaf = t.is_leaf
- if t.requires_grad and safe_is_leaf(r):
- r.requires_grad = True
- elif t.requires_grad and not is_leaf:
- r = torch._C._functions.DelayedError( # type: ignore[assignment]
- "Internal error: Tried to backward() through example input",
- 1,
- )(
- r # type: ignore[arg-type]
- )
- elif t.is_functional:
- assert t.unwrapped is not None
- assert t.current_level is not None
- ft = self.meta_tensor(
- t.unwrapped,
- shape_env=shape_env,
- callback=callback,
- # NB: reuse these exactly, we treat the
- # functional tensor as "invisible".
- # TODO: Actually this all probably doesn't
- # work, take a closer look.
- source=source,
- symbolic_context=symbolic_context,
- )
- r = _wrap_functional_tensor(ft, t.current_level)
- # TODO: is_leaf/requires_grad?
- else:
- assert t.stride is not None
- sizes = t.size
- strides = t.stride
- r = callback(
- lambda: torch.empty_strided(
- sizes,
- strides,
- dtype=t.dtype,
- device="meta",
- )
- )
- if self.copy_data:
- with torch.no_grad(), no_dispatch():
- r.real_tensor = torch.empty_strided( # type: ignore[attr-defined]
- t.size,
- t.stride,
- dtype=t.dtype,
- device=t.device,
- )
- assert t.data is not None
- _safe_copy(r.real_tensor, t.data) # type: ignore[attr-defined]
- return r
- r = _to_fake_tensor(t)
- elif t.is_functional and t.device.type not in ["xla", "lazy"]:
- assert t.unwrapped is not None
- assert not t.is_functorch_wrapped # handled above
- unwrapped = self.meta_tensor(
- t.unwrapped,
- shape_env=shape_env,
- callback=callback,
- source=source,
- symbolic_context=symbolic_context,
- )
- r = torch._to_functional_tensor(unwrapped)
- torch._mirror_autograd_meta_to(t.autograd_meta_from, r) # type: ignore[attr-defined]
- elif t.is_view:
- # Construct views in two steps: recursively meta-fy their
- # base, and then create view(s) off that. NB: doing it
- # directly from storage is WRONG because this won't cause
- # version counters to get shared.
- assert t.base is not None
- base_symbolic_context = None
- if shape_env and symbolic_context is not None:
- from torch.fx.experimental.symbolic_shapes import (
- StatelessSymbolicContext,
- )
- assert isinstance(symbolic_context, StatelessSymbolicContext)
- # NB: This should generally be set when the input is a view,
- # but the exception right now is for fake-ifying grads, which is
- # a work in progress.
- if symbolic_context.view_base_context is not None:
- base_symbolic_context = symbolic_context.view_base_context
- base = self.meta_tensor(
- t.base,
- shape_env,
- callback,
- source=torch._dynamo.source.AttrSource(source, "_base"),
- symbolic_context=base_symbolic_context,
- )
- def is_c_of_r(complex_dtype, real_dtype):
- return (
- utils.is_complex_dtype(complex_dtype)
- and utils.corresponding_real_dtype(complex_dtype)
- == real_dtype
- )
- # In some situations, MetaConverter may be called in a
- # context where autograd is disabled. For the _is_view
- # assert to pass, we have to set up the autograd view
- # metadata anyway. Do this by reenabling the
- # ADInplaceOrView key. This is kind of a hack.
- old_exclude = torch._C._dispatch_tls_is_dispatch_key_excluded(
- torch._C.DispatchKey.ADInplaceOrView
- )
- torch._C._dispatch_tls_set_dispatch_key_excluded(
- torch._C.DispatchKey.ADInplaceOrView, False
- )
- try:
- if base.dtype == t.dtype:
- pass
- elif is_c_of_r(base.dtype, t.dtype):
- base = torch.view_as_real(base)
- elif is_c_of_r(t.dtype, base.dtype):
- base = torch.view_as_complex(base)
- else:
- # This is not guaranteed to succeed. If it fails, it
- # means there is another dtype-converting view function
- # that hasn't been handled here
- base = base.view(t.dtype)
- # This is very tricky. Naively, you might expect this
- # to hold:
- #
- # if t.requires_grad and not safe_is_leaf(t)
- # assert t._base.requires_grad
- #
- # But it's not true! As you can see in the following
- # program:
- #
- # x = torch.zeros(4)
- # y = x.view(1, 4)
- # y.requires_grad = True
- # z = y.view(1, 1, 4)
- # assert z._base is x
- #
- # So we may have to do *two* views out of the base to
- # recreate this situation.
- if t.is_leaf:
- # Leaf views that track view metadata are created by
- # creating a view inside a no_grad block
- with torch.no_grad():
- r = view_from_base(base, t)
- # As it's a leaf, we can directly assign requires_grad
- r.requires_grad = t.requires_grad
- else:
- if t.base.requires_grad == t.requires_grad:
- # Easy case, just run the view op
- with torch.enable_grad():
- r = view_from_base(base, t)
- # NB: We don't actually faithfully replicate
- # autograd connectivity, but that doesn't matter
- # today. See following for more info:
- # https://gist.github.com/soulitzer/e03f015b314c3f5fcf80888c69390913
- else:
- # Obscure case. Create a leaf view and give it the
- # correct requires_grad, then do the final view.
- # NB: Can't have a non-leaf without requiring grad!
- assert t.requires_grad
- with torch.no_grad():
- mid = base.view(base.shape)
- mid.requires_grad = t.requires_grad
- with torch.enable_grad():
- r = view_from_base(mid, t)
- # The CreationMeta influences whether or not inplace
- # mutation is an error or not. So we need to make
- # sure we properly propagate this as well.
- assert t.creation_meta is not None
- torch._C._autograd._set_creation_meta(r, t.creation_meta)
- finally:
- torch._C._dispatch_tls_set_dispatch_key_excluded(
- torch._C.DispatchKey.ADInplaceOrView, old_exclude
- )
- else:
- is_leaf = t.is_leaf
- # Graph-Break for wrapped tensors
- if (
- not (t.is_batchedtensor or t.is_gradtrackingtensor)
- and t.is_functorch_wrapped
- ) or t.is_legacy_batchedtensor:
- return NotImplemented
- (
- sizes,
- strides,
- storage_offset,
- ) = sym_sizes_strides_storage_offset(t, source, symbolic_context)
- # If we have a subclass that desugars into dense tensors,
- # perform our callback on each inner tensor.
- if t.is_traceable_wrapper_subclass:
- r = empty_create_subclass(
- t, outer_size=sizes, outer_stride=strides
- )
- else:
- r = callback(
- lambda: torch.empty_strided(
- sizes,
- strides,
- dtype=t.dtype,
- device="meta",
- )
- )
- if self.copy_data:
- with torch.no_grad(), no_dispatch():
- assert t.size is not None
- assert t.stride is not None
- r.real_tensor = torch.empty_strided(
- t.size, t.stride, dtype=t.dtype, device=t.device
- )
- _safe_copy(r.real_tensor, t.data)
- assert safe_is_leaf(r), "the callback you passed in doesn't detach"
- if t.requires_grad:
- r.requires_grad = t.requires_grad
- if not is_leaf:
- # Fake up some autograd history.
- # Note: we *used* to call .clone() here to mock up some autograd history.
- # This is bad for subclasses.
- # Consider the case where you have a wrapper subclass that is contiguous,
- # but its inner tensor is noncontiguous().
- # .clone() (or other ops) will have the side effect of changing
- # the metadata of the inner tensor.
- # So instead, we now have a dedicated fn to set autograd history,
- # without inadvertently changing other metadata.
- r = torch._C._functions.DelayedError(
- "Internal error: Tried to backward() through example input",
- 1,
- )(r)
- s = t.storage
- assert s is not None
- if s.id not in self.storage_memo and (
- r.is_nested
- or (
- r.stride() == strides
- and r.storage_offset() == storage_offset
- )
- ):
- # You're normal and happy, install the fresh storage into the memo
- self.set_storage_memo(s, r.untyped_storage())
- if self.copy_data:
- r.untyped_storage().real_storage = (
- r.real_tensor.untyped_storage()
- )
- else:
- # You're in crazy town; somehow you gave us a tensor
- # that wasn't a view, but had nonzero storage offset,
- # nontrivial strides (such that clone() couldn't
- # preserve them), or already aliases with another
- # tensor's storage. The most typical way to end
- # up here is with set_. So use set_ to bludgeon this
- # in.
- r_s = self.meta_storage(s, callback=callback)
- # NB: In principle, this should always work, but there
- # is some subtle difference in the autograd metadata
- # that means we will backprop the set_ call, even if
- # r is declared as an input to grad.
- # See https://github.com/pytorch/pytorch/issues/87956
- # for the reproducer.
- # NB: The in_kernel_invocation_manager here is necessary
- # for fake tensor. If we run the set_ call with fake
- # tensor on, r will improperly report that it is NOT a
- # meta tensor but a cpu tensor, and then the set_ call
- # will fail due to device mismatch. no_dispatch() is
- # not enough, because the fake tensor will still claim
- # to be a CPU tensor and you'll end up in the CPU
- # kernel. Arguably this is a hack; a cleaner way to
- # solve this is to have a FakeStorage concept which
- # would report its CPU device--no problem now! But
- # this is difficult to do because we don't have storage
- # subclasses. Relevant test is
- # DynamicShapesFunctionTests::test_add_dynamic_shapes in
- # test/dynamo/test_dynamic_shapes.py
- maybe_fake_mgr: ContextManager[None] = contextlib.nullcontext()
- from torch._subclasses.fake_tensor import (
- in_kernel_invocation_manager,
- maybe_get_fake_mode,
- )
- mb_fake_mode = maybe_get_fake_mode(r)
- if mb_fake_mode is not None:
- maybe_fake_mgr = in_kernel_invocation_manager(mb_fake_mode)
- with torch.no_grad(), maybe_suppress():
- with maybe_fake_mgr:
- r.set_(r_s, storage_offset, sizes, strides)
- if self.copy_data:
- with torch.no_grad(), no_dispatch():
- r.real_tensor.set_(
- r_s.real_storage,
- t.storage_offset,
- t.size,
- t.stride,
- )
- if t.grad is not None:
- from torch._dynamo.source import AttrSource
- # TODO: Use a valid grad-specific symbolic context instead of recycling
- # the one from t. This isn't correct if e.g. t._is_view() != t.grad._is_view().
- r.grad = self.meta_tensor(
- t.grad,
- shape_env,
- callback,
- source=AttrSource(source, "grad"),
- symbolic_context=symbolic_context,
- )
- torch._C._set_conj(r, t.is_conj)
- torch._C._set_neg(r, t.is_neg)
- # This can be skipped if necessary for performance reasons
- skip_leaf = (
- t.is_gradtrackingtensor and t.level == GRAD_TENSOR_SENTINEL_VALUE
- )
- assert_metadata_eq(assert_eq, t, r, skip_symbolic=True, skip_leaf=skip_leaf)
- # Thanks to storage resizing, it's possible to end up with a tensor
- # that advertises a real size, but has a storage that actually has zero bytes.
- # Need to reflect this in the generated FakeTensor.
- if t.storage is not None and t.storage.size == 0:
- r.untyped_storage().resize_(0)
- if t.is_parameter:
- r._is_param = True
- self.set_tensor_memo(t, r)
- return self.get_tensor_memo(t)
- def __call__(
- self,
- t,
- shape_env=None,
- *,
- callback=lambda t: t(),
- source=None,
- symbolic_context=None,
- # Controls whether or not we should dump the tensor metadata to structured logs
- # when source is not None. Because we refakify after Dynamo is done,
- # we don't want to dump info again from AOTAutograd; it is redundant.
- trace=True,
- ):
- # TODO: zero tensors? We appear to have eliminated them by
- # excluding complex for now
- # Filter out cases we don't support
- # TODO: This can probably be simplified quite a bit
- if isinstance(t, torch.Tensor) or is_traceable_wrapper_subclass(t):
- if (
- # Lazy tensors are not supported. Note that XLA is
- # implemented on top of lazy tensors but is not excluded here; we
- # have some special handling for it (this is for the XLA Dynamo
- # integration)
- t.device.type == "lazy"
- or
- # Quantization is not supported
- t.is_quantized
- or
- # Views out of sparse tensors are not currently supported (plain
- # sparse is supported though)
- (t._is_view() and t._base is not None and t._base.is_sparse)
- ):
- self.miss += 1
- return NotImplemented
- else:
- self.hit += 1
- elif torch.overrides.is_tensor_like(t):
- self.miss += 1
- return NotImplemented
- else:
- # non-Tensor types don't count as hit or miss
- return t
- if source is None:
- trace = False
- # Describe the tensor. NB: do NOT disable ambient modes; we may need
- # to query them when figuring out what to put in here
- t_desc = self.describer.describe_tensor(t, trace=trace)
- if trace:
- trace_structured(
- "describe_source",
- metadata_fn=lambda: {
- "describer_id": self.describer.id,
- "id": t_desc.id,
- "source": source.name(),
- },
- )
- # Do the meta-fication. Here, we disable all the ambient modes, to
- # better simulate what it would be like to re-fakeify from a fresh
- # process
- with contextlib.ExitStack() as exit_stack:
- exit_stack.enter_context(torch._dispatch.python.suspend_functionalization())
- st = peek_interpreter_stack()
- if st is not None:
- exit_stack.enter_context(
- torch._functorch.pyfunctorch.temporarily_clear_interpreter_stack()
- )
- r = self.meta_tensor(
- t_desc,
- shape_env=shape_env,
- callback=callback,
- source=source,
- symbolic_context=symbolic_context,
- )
- if type(t) is torch.nn.Parameter:
- # NB: Cannot directly use Parameter constructor
- # because that would force a detach, not desirable
- r._is_param = True
- # TODO: return the description for later
- return r
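- # NB: imported at the bottom of the file, presumably to avoid an import cycle;
- # `utils` is referenced above in is_c_of_r when reconciling view dtypes.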
- import torch._prims_common as utils
|