# mypy: allow-untyped-defs
import contextlib
import warnings
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Union

import torch
import torchgen
import torchgen.model
from torch._C import (
    _get_dispatch_stack_at,
    _len_torch_dispatch_stack,
    _pop_torch_dispatch_stack,
    _push_on_torch_dispatch_stack,
    DispatchKey,
)

# TODO: Limitations and things about enable_torch_dispatch_mode we should fix before exposing it:
# - We need a better user-facing api for _DisableTorchDispatch that
#   is able to selectively disable __torch_dispatch__ of a particular class.
# - It doesn't work with the tensor constructors (torch.tensor, torch.Tensor)
# - Better name (see https://github.com/pytorch/pytorch/pull/63496#discussion_r694091694)

_is_in_torch_dispatch_mode = False


def is_in_torch_dispatch_mode() -> bool:
    return _is_in_torch_dispatch_mode


class TorchDispatchMode:
    """
    A ``TorchDispatchMode`` allows you to override the meaning of all
    ``__torch_dispatch__`` overrideable functions within a dynamic scope,
    without having to actually create a tensor subclass or manually
    monkey-patch functions in the PyTorch API.  Some common situations
    where you should use a mode:

    * You want to override the meaning of factory functions, or other
      functions that do not otherwise take a tensor as an argument
      (these cannot be overridden with tensor subclasses).

    * You want to override the behavior of all functions without needing
      to wrap your inputs in tensor subclasses; e.g., if you are just
      interested in logging intermediate computations.

    * You want to control the order of execution of various tensor
      subclasses explicitly, rather than implicitly via the return of
      ``NotImplemented``.

    Independent subclasses of :class:`TorchDispatchMode` are compositional:
    modes can be pushed onto a stack using ``with MyMode():``.
    When you call functions in the PyTorch API inside your
    ``__torch_dispatch__`` implementation, by default, they will forward on to
    the next mode on the mode stack.  If you want to recursively call back into
    your current ``__torch_dispatch__`` implementation, either explicitly
    invoke ``self.__torch_dispatch__(...)``, or use the context manager
    ``__torch_dispatch__(self)`` to make the PyTorch
    API self-referential (beware of infinite loops, in this case!).
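
    Example (a minimal sketch; ``LoggingMode`` is an illustrative name, not an
    existing class)::

        class LoggingMode(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                kwargs = kwargs or {}
                print(f"dispatching {func}")
                return func(*args, **kwargs)

        with LoggingMode():
            torch.ones(2) + 1  # each dispatched aten op is printed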
- """

    def __init__(self, _dispatch_key=None):
        if _dispatch_key is not None:
            assert isinstance(_dispatch_key, torch._C.DispatchKey)
            self.__dict__["_dispatch_key"] = _dispatch_key

        self.old_dispatch_mode_flag = False

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        raise NotImplementedError

    def __enter__(self):
        global _is_in_torch_dispatch_mode
        self.old_dispatch_mode_flag = _is_in_torch_dispatch_mode
        _is_in_torch_dispatch_mode = True
        _push_mode(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        mb_dk_or_mode_key = self.__dict__.get("_dispatch_key", None)
        if mb_dk_or_mode_key is None:
            # Today, mode keys are not used at all in the per-dispatch-key-mode logic (for pre-dispatch)
            # We should probably revisit this.
            mb_dk_or_mode_key = self.__dict__.get("_mode_key", None)
        global _is_in_torch_dispatch_mode
        _is_in_torch_dispatch_mode = self.old_dispatch_mode_flag
        _pop_mode(mb_dk_or_mode_key)

    @classmethod
    def push(cls, *args, **kwargs):
        warnings.warn(
            "`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`"
        )
        instance = cls(*args, **kwargs)
        return instance


def _get_current_dispatch_mode():
    stack_len = _len_torch_dispatch_stack()
    # Return a user mode on the stack if there are any
    if stack_len > 0:
        return _get_dispatch_stack_at(stack_len - 1)
    return None


def _detect_infra_mode(key):
    assert key in [
        torch._C._TorchDispatchModeKey.FUNCTIONAL,
        torch._C._TorchDispatchModeKey.PROXY,
    ]
    from torch._ops import _get_dispatch_mode_pre_dispatch

    pre_dispatch_mode = _get_dispatch_mode_pre_dispatch(key)
    post_dispatch_mode = torch._C._get_dispatch_mode(key)

    assert (pre_dispatch_mode is None) or (post_dispatch_mode is None)

    if pre_dispatch_mode is None:
        return post_dispatch_mode

    return pre_dispatch_mode


def _unset_infra_mode(key):
    from torch._ops import _get_dispatch_mode_pre_dispatch, unset_mode_pre_dispatch

    pre_dispatch_mode = _get_dispatch_mode_pre_dispatch(key)
    post_dispatch_mode = torch._C._get_dispatch_mode(key)
    if pre_dispatch_mode and post_dispatch_mode:
        raise AssertionError(
            "Can't have active infra mode on both pre and post dispatch mode stack"
        )

    if pre_dispatch_mode:
        mode = unset_mode_pre_dispatch(key)
        return mode
    if post_dispatch_mode:
        return torch._C._unset_dispatch_mode(key)


@contextlib.contextmanager
def _disable_infra_mode(key):
    assert key in (
        torch._C._TorchDispatchModeKey.FUNCTIONAL,
        torch._C._TorchDispatchModeKey.PROXY,
    )
    mode_unset = _unset_infra_mode(key)
    try:
        yield mode_unset
    finally:
        if mode_unset is not None:
            _push_mode(mode_unset)


def _get_current_dispatch_mode_stack():
    stack_len = _len_torch_dispatch_stack()
    return [_get_dispatch_stack_at(i) for i in range(stack_len)]


def _push_mode(mode: TorchDispatchMode):
    k = mode._dispatch_key if hasattr(mode, "_dispatch_key") else None
    assert k is None or k == torch._C.DispatchKey.PreDispatch
    if k is None:
        _push_on_torch_dispatch_stack(mode)
        return

    from torch._ops import _set_mode_pre_dispatch, get_cached_ops

    # See Note [Not Caching Per-Dispatch-Key Mode Handlers]
    # Clear the cache of every op that has been used so far, for this particular key.
    ks = torch._C._functionality_to_backend_keys(k)
    for op in get_cached_ops():
        for key in ks:
            op._uncache_dispatch(key)
    _set_mode_pre_dispatch(mode)


def _pop_mode(k: Optional[Union[DispatchKey, torch._C._TorchDispatchModeKey]] = None):
    if k == torch._C.DispatchKey.PreDispatch:  # type: ignore[attr-defined]
        from torch._ops import _pop_mode_from_pre_dispatch

        return _pop_mode_from_pre_dispatch()

    if k is None or isinstance(k, torch._C._TorchDispatchModeKey):
        return _pop_torch_dispatch_stack(k)


@contextlib.contextmanager
def _pop_mode_temporarily(k: Optional[DispatchKey] = None):
    old = _pop_mode(k)
    try:
        yield old
    finally:
        _push_mode(old)


@contextlib.contextmanager
def _disable_current_modes():
    from torch._ops import (
        _len_torch_dispatch_stack_pre_dispatch,
        _pop_mode_from_pre_dispatch,
    )
    from torch._subclasses.functional_tensor import FunctionalTensorMode
    from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode
    from torch._subclasses.schema_check_mode import SchemaCheckMode

    mode_len_pre_dispatch = _len_torch_dispatch_stack_pre_dispatch()
    old_pre_dispatch_modes = [
        _pop_mode_from_pre_dispatch() for _ in range(mode_len_pre_dispatch)
    ]

    has_proxy_mode_in_pre_dispatch = False
    has_functional_mode_in_pre_dispatch = False
    has_schema_check_mode_in_pre_dispatch = False

    for i in old_pre_dispatch_modes:
        if isinstance(i, ProxyTorchDispatchMode):
            has_proxy_mode_in_pre_dispatch = True
        if isinstance(i, FunctionalTensorMode):
            has_functional_mode_in_pre_dispatch = True
        if isinstance(i, SchemaCheckMode):
            has_schema_check_mode_in_pre_dispatch = True

    mode_len = _len_torch_dispatch_stack()
    old_modes = [_pop_mode() for _ in range(mode_len)]

    for old in old_modes:
        if (
            isinstance(old, FunctionalTensorMode)
            and has_functional_mode_in_pre_dispatch
        ):
            raise AssertionError(
                "Can't have FunctionalMode available both in PreDispatch and Python Key"
            )
        if isinstance(old, ProxyTorchDispatchMode) and has_proxy_mode_in_pre_dispatch:
            raise AssertionError(
                "Can't have ProxyTorchDispatchMode available both in PreDispatch and Python Key"
            )
        if (
            isinstance(old, SchemaCheckMode)
            and has_schema_check_mode_in_pre_dispatch
        ):
            raise AssertionError(
                "Can't have SchemaCheckMode available both in PreDispatch and Python Key"
            )

    # Manually disable proxy and fake modes, if any are active
    try:
        yield old_pre_dispatch_modes + old_modes
    finally:
        for mode in reversed(old_modes):
            _push_mode(mode)
        for mode in reversed(old_pre_dispatch_modes):
            _push_mode(mode)


class BaseTorchDispatchMode(TorchDispatchMode):
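    """
    A no-op mode: its ``__torch_dispatch__`` simply re-dispatches ``func`` on
    the given arguments and returns the result.  It can serve as a convenient
    base class for modes that only want to observe calls.
    """
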
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        return func(*args, **kwargs)


def is_traceable_wrapper_subclass(t):
    """
    Returns whether or not a tensor subclass that implements __torch_dispatch__
    is 'traceable' with torch.compile.
    In order for a tensor subclass to support TorchDispatchMode-style tracing in PT2,
    it must implement two magic methods: __tensor_flatten__ and __tensor_unflatten__.
    It is also expected to obey some restrictions around traceability and aliasing:

    * The subclass's __torch_dispatch__() implementation should desugar into pytorch
      dispatcher operations that can be traced into a graph.
    * The subclass should use return_and_correct_aliasing(). This is needed today to make
      sure that torch.compile does the right thing in a few cases around input mutation
      and output aliasing.

    Expected magic method signatures:
        attrs, ctx = t.__tensor_flatten__()
            attrs: list of attribute name strings for inner tensors
            ctx: dict containing any other subclass-specific metadata needed for unflattening

        t = MySubClass.__tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride)
            inner_tensors: dict mapping attribute name -> tensor for each inner tensor
            ctx: dict with subclass metadata in the form that __tensor_flatten__() produces
            outer_size: expected (possibly symbolic) size that the returned subclass
                instance should have. Note that this arg is useful for certain subclasses
                that require the shape info to be constructed. In most cases, this arg can be
                safely ignored.
            outer_stride: expected (possibly symbolic) stride that the returned subclass
                instance should have. Note that this arg is useful for certain subclasses
                that require the stride info to be constructed. In most cases, this arg can be
                safely ignored.
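
    A minimal sketch of the protocol (the ``WrapperTensor`` class below is
    purely illustrative; a real subclass would also define
    ``__torch_dispatch__``)::

        class WrapperTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, inner):
                return torch.Tensor._make_wrapper_subclass(
                    cls, inner.shape, dtype=inner.dtype, device=inner.device
                )

            def __init__(self, inner):
                self.inner = inner

            def __tensor_flatten__(self):
                # attribute names of inner tensors, plus extra metadata (none here)
                return ["inner"], None

            @staticmethod
            def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
                return WrapperTensor(inner_tensors["inner"])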
- """
- is_subclass = isinstance(t, torch.Tensor) and type(t) != torch.Tensor
- return (
- is_subclass
- and hasattr(t, "__tensor_flatten__")
- and hasattr(t, "__tensor_unflatten__")
- )


def transform_subclass(t, callback, outer_size=None, outer_stride=None):
    """
    Given a traceable, wrapper tensor subclass ``t`` that implements
    ``__torch_dispatch__`` and holds some inner tensors,
    and a callback of type ``Callable[[str, torch.Tensor], torch.Tensor]``,
    ``transform_subclass`` will construct a fresh instance of the wrapper tensor subclass.
    It will do so by grabbing each inner tensor attribute from the wrapper,
    passing them into ``callback`` to get a transformed tensor,
    and putting each transformed tensor into the fresh tensor subclass instance.

    Note: this function will not handle ensuring that the fresh subclass
    gets the same (autograd, and aliasing) metadata as the original tensor.
    This is generally handled in other subsystems like AOTAutograd.
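
    A minimal sketch, assuming ``t`` is a traceable wrapper subclass: the call
    below builds a fresh instance of ``type(t)`` whose inner tensors have been
    detached::

        detached = transform_subclass(t, lambda attr, inner: inner.detach())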
- """
- outer_size = outer_size if outer_size is not None else t.size()
- outer_stride = outer_stride if outer_stride is not None else t.stride()
- attrs, ctx = t.__tensor_flatten__()
- transformed_tensors_dict = {}
- for attr in attrs:
- transformed_tensors_dict[attr] = callback(attr, getattr(t, attr))
- sub = type(t).__tensor_unflatten__(
- transformed_tensors_dict, ctx, outer_size, outer_stride
- )
- # NB: Purposefully guard here to simplify the inner / outer symbols.
- # Using sym_eq() for symbolic comparison can result in an expression that's too
- # difficult to guard on, so we use == here.
- assert sub.shape == outer_size, (
- f"Expected return value from {type(t)}__tensor_unflatten__() to have "
- f"shape equal to {outer_size}, but got: {sub.shape}"
- )
- assert sub.stride() == outer_stride, (
- f"Expected return value from {type(t)}__tensor_unflatten__() to have "
- f"stride equal to {outer_stride}, but got: {sub.stride()}"
- )
- return sub


def _correct_storage_aliasing(func, schema_info, args, outs):
    """
    Given: an OpOverload, a SchemaInfo (cached information from torchgen about schema),
    and the inputs/outputs to the OpOverload,
    this function checks to see if func is a view operator
    (by checking if any of the outputs in the op's schema
    are immutable aliases of inputs).
    If so, this function manually aliases the storage of the output tensor
    with its corresponding input tensor alias.
    It does this by unsafely overwriting the storage field of the output tensor
    to be the same storage as the input.
    """
    assert isinstance(func, torch._ops.OpOverload)
    assert isinstance(args, tuple)
    assert isinstance(outs, (list, tuple))
    flat_outs = torch.utils._pytree.tree_leaves(outs)

    def alias_non_inplace_storage(arg, ret):
        # This is hopefully a reasonable assert:
        # subclasses that rely on this API for output aliasing
        # should always return wrapper tensor subclasses for us to manually alias.
        # In theory, if a subclass that needs this API wants to sometimes return
        # plain tensors, we could remove the assert and just not perform the aliasing,
        # but it seems safer to learn more about this case first.
        if is_traceable_wrapper_subclass(arg) or is_traceable_wrapper_subclass(ret):
            ret_list = ret if isinstance(ret, list) else [ret]
            for r in ret_list:
                assert type(arg) == type(
                    r
                ), f"""Called {str(func)} with input of type {type(arg)}
and output of type {type(ret)}. But expected types to match."""
            # Need to run under no_dispatch, because we explicitly do **not**
            # want our subclass to intercept the set_() call.
            # Instead, our subclass should directly have its storage swapped out.
            with torch.utils._mode_utils.no_dispatch():
                # See Note: [Fake Tensor Dispatch Keys]
                # we're borrowing the way it modifies dispatch key TLS.
                meta_in_tls = torch._C._meta_in_tls_dispatch_include()
                torch._C._set_meta_in_tls_dispatch_include(True)
                try:
                    # We call this overload directly, and pass ret.shape, because we **explicitly**
                    # don't want to reset the sizes on ret, if the storage implies a size change.
                    # Why?
                    # The purpose of this API is *not* to change the size/strides of our output - we assume it's already correct.
                    # We just want to "fix up" the storage aliasing, without modifying our output's metadata.
                    # Example: out = inp.expand(inp.shape[0], inp.shape[0])
                    # This requires swapping the storage of out to be the same as inp,
                    # but we do *not* want it to change the sizes/strides that were computed for out.
                    if isinstance(ret, list):
                        for r in ret:
                            torch.ops.aten.set_.source_Storage_storage_offset(
                                r,
                                arg.untyped_storage(),
                                r.storage_offset(),
                                r.shape,
                                r.stride(),
                            )
                    else:
                        assert isinstance(ret, torch.Tensor), f"type: {type(ret)}"
                        torch.ops.aten.set_.source_Storage_storage_offset(
                            ret,
                            arg.untyped_storage(),
                            ret.storage_offset(),
                            ret.shape,
                            ret.stride(),
                        )
                finally:
                    torch._C._set_meta_in_tls_dispatch_include(meta_in_tls)

    def is_read_only_alias_match(arg, ret):
        shared_aliases = arg.alias_set & ret.alias_set
        return len(shared_aliases) > 0 and not arg.is_write

    num_args = len(func._schema.arguments)
    num_returns = len(func._schema.returns)
    for arg_idx in range(num_args):
        for return_idx in range(num_returns):
            if is_read_only_alias_match(
                schema_info.args[arg_idx], schema_info.outs[return_idx]
            ):
                alias_non_inplace_storage(args[arg_idx], outs[return_idx])


# This abstracts over the fact that in return_and_correct_aliasing,
# we sometimes use torchgen schema parsing (for aten ops, since torchscript's schema parsing is sometimes buggy),
# and sometimes use torchscript schema parsing (for custom ops, for which torchgen parsing is untested).
@dataclass
class AliasInfo:
    alias_set: Set[str]
    is_write: bool
    name: Optional[str]


@dataclass
class SchemaInfo:
    args: List[AliasInfo]
    outs: List[AliasInfo]


# Can't import torch._ops.OpOverload due to circular reference
parsed_schema_map: Dict[Any, SchemaInfo] = {}


# Given an OpOverload, returns schema information on it.
# This is cached for efficiency, since it can involve running torchgen
def get_alias_info(func) -> SchemaInfo:
    if func in parsed_schema_map:
        return parsed_schema_map[func]
    # For ATen ops: use torchgen (since torchscript parser doesn't handle alias annotations
    # properly for some ops that output tensorlists)
    if func.namespace == "aten":
        torchgen_schema_str = str(func._schema)
        assert torchgen_schema_str.startswith("aten::")
        # remove the aten:: namespace, which is added by the torchscript parser,
        # and torchgen doesn't know how to handle
        torchgen_schema_str = torchgen_schema_str[6:]
        import re

        # the torchscript parser ends up converting int[2]=1 into int[2]=[1, 1],
        # which torchgen chokes on.
        torchgen_schema_str = re.sub(r"=\[[0, ]+\]", "=0", torchgen_schema_str)
        torchgen_schema_str = re.sub(r"=\[[1, ]+\]", "=1", torchgen_schema_str)
        # for aten::rot90
        torchgen_schema_str = torchgen_schema_str.replace("=[0, 1]", "=[0,1]")
        torchgen_schema = torchgen.model.FunctionSchema.parse(torchgen_schema_str)
        arg_schemas = [
            AliasInfo(
                alias_set=(
                    set() if a.annotation is None else set(a.annotation.alias_set)
                ),
                is_write=a.annotation is not None and a.annotation.is_write,
                name=a.name,
            )
            for a in torchgen_schema.arguments.flat_all
        ]
        out_schemas = [
            AliasInfo(
                alias_set=(
                    set() if a.annotation is None else set(a.annotation.alias_set)
                ),
                is_write=a.annotation is not None and a.annotation.is_write,
                name=a.name,
            )
            for a in torchgen_schema.returns
        ]
    else:
        # For non-aten ops, torchgen is untested so we rely on torchscript schema parsing
        arg_schemas = [
            AliasInfo(
                alias_set=(
                    set() if a.alias_info is None else set(a.alias_info.before_set)
                ),
                is_write=a.alias_info is not None and a.alias_info.is_write,
                name=a.name,
            )
            for a in func._schema.arguments
        ]
        out_schemas = [
            AliasInfo(
                alias_set=(
                    set() if a.alias_info is None else set(a.alias_info.before_set)
                ),
                is_write=a.alias_info is not None and a.alias_info.is_write,
                name=a.name,
            )
            for a in func._schema.returns
        ]
    schema_info = SchemaInfo(args=arg_schemas, outs=out_schemas)
    parsed_schema_map[func] = schema_info
    return schema_info
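
# For illustration (a hedged example, not executed anywhere): for an in-place op such as
# ``aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)``,
# get_alias_info reports the ``self`` argument and the single return value as sharing
# alias set {"a"} with is_write=True, which is how return_and_correct_aliasing below
# recognizes that the op's output is just the mutated input.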


def return_and_correct_aliasing(func, args, kwargs, out):
    """
    This function should be used by wrapper tensor ``__torch_dispatch__`` subclasses
    that would like to work with torch.compile. It ensures that the subclass
    properly implements the aliasing behavior of every op,
    which is needed for correctness in AOTAutograd.
    This function will handle:

    * When we see a view op, we will alias the storages of any
      input and output tensor subclasses

    * When we see an inplace or out= op, we will directly
      return the corresponding input tensor, instead of returning
      a (potentially) fresh output tensor.
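
    A minimal sketch of the intended call pattern inside a wrapper subclass
    (``WrapperTensor`` and the ``unwrap``/``rewrap`` helpers are illustrative,
    not part of this module)::

        class WrapperTensor(torch.Tensor):
            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                kwargs = kwargs or {}
                # unwrap -> run the op on inner tensors -> rewrap, then let
                # return_and_correct_aliasing fix up view / inplace / out= semantics
                out = rewrap(func(*unwrap(args), **unwrap(kwargs)))
                return return_and_correct_aliasing(func, args, kwargs, out)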
- """
- # Caching here because torchgen parsing is definitely not fast, and this function is called
- # once for every op in the graph during functionalization.
- schema_info = get_alias_info(func)
- def get_write_alias(x):
- if len(x.alias_set) == 0:
- return None
- alias_set = list(x.alias_set)
- # torchscript allows for complicated alias sets, but our dispatcher ops only really involve simple aliasing
- assert len(alias_set) == 1
- if x.is_write:
- return alias_set[0]
- return None
- def get_arg_from_alias(output_alias, schema_info, args, kwargs):
- new_args, new_kwargs = torch.fx.operator_schemas.normalize_function(
- func, args=args, kwargs=kwargs
- )
- arg_indices = [
- i for i, a in enumerate(schema_info.args) if output_alias in a.alias_set
- ]
- # For any dispatcher op with an output alias, we expect it to map to exactly one alias in the schema's input arguments.
- assert len(arg_indices) == 1
- idx = arg_indices[0]
- arg_info = schema_info.args[idx]
- if arg_info.name is not None and arg_info.name in new_kwargs:
- return new_kwargs[arg_info.name]
- return new_args[idx]
- # Fix up the storages of any outs so that they point to the same storage as the input,
- # if func is a view op.
- _correct_storage_aliasing(
- func, schema_info, args, (out,) if not isinstance(out, tuple) else out
- )
- # For inplace_view ops in particular, we'll try hard to make sure that the wrapper subclass's
- # metadata is set correctly.
- if torch.Tag.inplace_view in func.tags:
- # no_dispatch() to make sure that we secretly change the metadata on the wrapper,
- # but don't end up dispatching the op anywhere else.
- mutated_args = [
- x
- for i, x in enumerate(args)
- if get_write_alias(schema_info.args[i]) is not None
- ]
- # Assumption: we have a very small number of inplace_view ops that follow a strict schema:
- # there is only a single argument that gets its metadata mutated.
- assert len(mutated_args) == 1
- # This check exists because we generally *do* want to update the metadata of any wrapper subclasses,
- # but FunctionalTensor is special: it overrides all size/stride calls to plumb to the inner tensor.
- # so we don't actually need to update the metadata (and attempting to do so causes errors)
- from torch._subclasses.functional_tensor import FunctionalTensor
- if not isinstance(mutated_args[0], FunctionalTensor):
- with torch.utils._mode_utils.no_dispatch():
- # See Note: [Fake Tensor Dispatch Keys]
- # we're borrowing the way it modifies dispatch key TLS.
- meta_in_tls = torch._C._meta_in_tls_dispatch_include()
- torch._C._set_meta_in_tls_dispatch_include(True)
- try:
- func(*args, **kwargs)
- finally:
- torch._C._set_meta_in_tls_dispatch_include(meta_in_tls)
- # Next: we need to make sure to return inputs directly, if the output is a mutable alias (e.g. add_()).
- # simple case: none of our outputs have mutable aliases, so we can return the output as-is
- if not any(get_write_alias(r) is not None for r in schema_info.outs):
- return out
- # simplifying assumption: we don't have **any** ops with return types like "-> (Tensor(a!), Tensor)"
- if not all(get_write_alias(r) is not None for r in schema_info.outs):
- raise RuntimeError("Unsupported schema: " + str(func._schema))
- if len(func._schema.returns) == 1:
- return get_arg_from_alias(
- get_write_alias(schema_info.outs[0]), schema_info, args, kwargs
- )
- # In the multi-return case, all aten ops return a tuple / list, so cast accordingly.
- outs_to_return = type(out)(
- [
- (
- get_arg_from_alias(
- get_write_alias(schema_info.outs[i]), schema_info, args, kwargs
- )
- if get_write_alias(r) is not None
- else o
- )
- for ((i, r), o) in zip(enumerate(schema_info.outs), out)
- ]
- )
- return outs_to_return