# mypy: allow-untyped-defs
import operator
from collections import defaultdict
from typing import Any, Callable, DefaultDict, Dict, Optional, Tuple, Type

import sympy

import torch
import torch.fx
from torch.fx.experimental.symbolic_shapes import (
    compute_unbacked_bindings,
    rebind_unbacked,
    statically_known_true,
    sym_eq,
)
from torch.utils import _pytree as pytree
from torch.utils._pytree import tree_map

from .virtualized import V
# Check whether (node.args[0], node) matches the pattern
# (nn.Module subclass, F.function / torch.Tensor method).
# Works for length-2 patterns with one module and one function/method.
def matches_module_function_pattern(
    pattern: Tuple[Type[torch.nn.modules.Module], Callable[..., Any]],
    node: torch.fx.node.Node,
    modules: Dict[str, torch.nn.modules.Module],
) -> bool:
    if len(node.args) == 0:
        return False
    if not isinstance(node.args[0], torch.fx.Node) or not isinstance(
        node, torch.fx.Node
    ):
        return False
    # the first node is call_module
    if node.args[0].op != "call_module":
        return False
    if not isinstance(node.args[0].target, str):
        return False
    if node.args[0].target not in modules:
        return False
    if type(modules[node.args[0].target]) is not pattern[0]:
        return False
    # the second node is call_function or call_method
    if node.op != "call_function" and node.op != "call_method":
        return False
    if node.target != pattern[1]:
        return False
    # make sure the output of node.args[0] is only used by the current node
    if len(node.args[0].users) > 1:
        return False
    return True
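

# A minimal usage sketch (illustrative; the module `M` and the pattern below
# are assumptions made for this example, not part of this file):
#
#     import torch.nn.functional as F
#
#     class M(torch.nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.linear = torch.nn.Linear(4, 4)
#
#         def forward(self, x):
#             return F.relu(self.linear(x))
#
#     gm = torch.fx.symbolic_trace(M())
#     modules = dict(gm.named_modules())
#     for node in gm.graph.nodes:
#         # True exactly at the F.relu call_function node whose sole input
#         # is the output of the `linear` call_module node.
#         if matches_module_function_pattern(
#             (torch.nn.Linear, F.relu), node, modules
#         ):
#             ...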


class FakeTensorUpdater:
    """
    The main idea here is that it's difficult to maintain accurate fake
    tensors (our primary form of metadata) for each node in our graph as we
    transform it.

    The most reliable way to obtain this information is by rerunning
    faketensor propagation. However, in general, faketensor propagation is
    fairly expensive. So, instead, we'd like to rerun faketensor propagation
    only on nodes that have changed.

    In order to detect which nodes have changed, we hash each node's target
    and argument lists (which are immutable in FX).

    Then, whenever we call incremental_update, we check which FX nodes have
    a new hash, and recompute the faketensor metadata for those nodes. We
    then continue to recursively recompute the faketensors of all users
    until the fake tensors stop changing.
    """

    def __init__(self, graph: torch.fx.Graph):
        self.processed_hashes = set()
        self.graph = graph

        for node in self.graph.nodes:
            self.processed_hashes.add(self.hash_node(node))

    def hash_node(self, node: torch.fx.Node):
        # todo(chilli): Not a great hash function
        return (node, node.target, id(node.args), id(node.kwargs))

    def incremental_update(self):
        existing_storages: DefaultDict[Optional[int], int] = defaultdict(int)
        for node in self.graph.nodes:
            existing_storages[get_node_storage(node)] += 1
        def is_intlist_same(new, old):
            return statically_known_true(sym_eq(new, old))

        def is_fake_tensor_same(new, old):
            if type(new) != type(old):
                return False
            if isinstance(new, (list, tuple)):
                if len(new) != len(old):
                    return False
                return all(
                    is_fake_tensor_same(new_i, old_i) for new_i, old_i in zip(new, old)
                )
            if new is None:
                return old is None
            if not isinstance(new, torch.Tensor):
                assert isinstance(
                    new, (torch.SymInt, torch.SymBool, torch.SymFloat)
                ), f"Unknown type {type(new)} in {self.graph}"
                return (
                    new.node.shape_env._maybe_evaluate_static(
                        sympy.Eq(new.node.expr, old.node.expr)
                    )
                    == sympy.true
                )
            if not is_intlist_same(new.shape, old.shape) or new.layout != old.layout:
                return False
            if new.layout == torch.strided and (
                not is_intlist_same(new.stride(), old.stride())
                or not statically_known_true(
                    new.storage_offset() == old.storage_offset()
                )
            ):
                return False
            if new.device != old.device:
                return False

            if get_storage(new) == get_storage(old):
                return True

            # This is the case where it returns a completely fresh storage
            # that's used nowhere else.
            if (
                existing_storages[get_storage(old)] == 1
                and get_storage(new) not in existing_storages
            ):
                return True
            return False
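
        # To summarize: two fake tensors are treated as "the same" when all
        # of the metadata above matches and either they alias one storage,
        # or the old storage was unique and the new storage is entirely
        # fresh (i.e. a pure recompute rather than new aliasing).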

        def should_process_node(node):
            # For nodes that return True here, node.target is re-invoked
            # under fake mode; that does not work for inductor lowerings.
            # So we only accept nodes whose target is an aten operator or
            # operator.getitem (which is used to unpack ops that return
            # multiple tensors).
            return node.op == "call_function" and (
                isinstance(node.target, torch._ops.OpOverload)
                or node.target == operator.getitem
            )
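
        # For example, a call_function node targeting
        # torch.ops.aten.mm.default is reprocessed below, while call_module
        # nodes and non-aten call_function targets are skipped.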

        to_process = set()
        for node in self.graph.nodes:
            if (
                self.hash_node(node) in self.processed_hashes
                and id(node) not in to_process
            ):
                continue

            if not should_process_node(node):
                continue

            is_valid, args, kwargs = get_fake_args_kwargs(node)
            if not is_valid:
                continue

            with V.fake_mode:
                new_fake_tensor = node.target(*args, **kwargs)
            if "val" in node.meta and is_fake_tensor_same(
                new_fake_tensor, node.meta["val"]
            ):
                continue

            rebind_unbacked(V.fake_mode.shape_env, node, new_fake_tensor)

            node.meta["val"] = new_fake_tensor
            if (shape_env := V.fake_mode.shape_env) and (
                symbol_to_path := compute_unbacked_bindings(shape_env, new_fake_tensor)
            ):
                # Refresh the bindings to the new symbols
                node.meta["unbacked_bindings"] = symbol_to_path

            existing_storages[get_node_storage(node)] += 1

            to_process.update([id(user) for user in node.users])
            self.processed_hashes.add(self.hash_node(node))
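

# A minimal usage sketch (illustrative; assumes a GraphModule `gm` whose
# nodes already carry fake tensors in node.meta["val"], and an active
# V.fake_mode). Hash the graph before mutating it, then refresh only the
# metadata that the mutation invalidated:
#
#     updater = FakeTensorUpdater(gm.graph)
#     # ... mutate gm.graph, e.g. rewrite a node's args or swap a target ...
#     updater.incremental_update()  # recomputes "val" for changed nodes
#                                   # and their affected users only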


def get_storage(t: torch.Tensor) -> int:
    return t.untyped_storage()._cdata


def get_node_storage(node: torch.fx.Node) -> Optional[int]:
    if "val" not in node.meta:
        return None
    if not isinstance(node.meta["val"], torch.Tensor):
        return None
    if not torch._C._has_storage(node.meta["val"]):
        return None
    return get_storage(node.meta["val"])


def get_fake(x):
    if isinstance(x, torch.fx.Node):
        if "val" not in x.meta:
            return x
        return x.meta["val"]
    return x


def get_fake_args_kwargs(
    x: torch.fx.Node,
) -> Tuple[bool, Tuple[Any, ...], Dict[str, Any]]:
    """
    The first return value is False if any of the input nodes does not yet
    have a fake tensor, and True otherwise; the remaining values are x's
    args and kwargs with every FX node replaced by its fake tensor.
    """
    args, kwargs = tree_map(get_fake, (x.args, x.kwargs))
    if any(
        isinstance(a, torch.fx.Node) for a in pytree.arg_tree_leaves(*args, **kwargs)
    ):
        return False, args, kwargs
    return True, args, kwargs
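

# For example (illustrative): for a node computed as
# `z = torch.ops.aten.add.Tensor(x, y)` where both inputs already carry
# fake tensors, get_fake_args_kwargs returns (True, (fake_x, fake_y), {});
# if either input is missing node.meta["val"], the first value is False
# and the un-substituted node is left in args.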


def is_node_realized(node: torch.fx.Node) -> bool:
    """Returns true if a node is always realized when lowered to inductor IR.

    NOTE: This may return some false negatives. e.g. it doesn't
    handle buffers realized heuristically during lowering, or
    buffers realized indirectly through view ops.
    """
    from torch._inductor.lowering import fallbacks, needs_realized_inputs

    def is_buffer(node: torch.fx.Node) -> bool:
        if node.op == "call_function" and node.target is operator.getitem:
            # For nodes with multiple outputs, we get the fx graph:
            #     foo = torch.ops.aten.foo(...)
            #     getitem = foo[0]
            #     getitem_1 = foo[1]
            # where we need to check if foo is a fallback kernel
            return is_buffer(node.args[0])  # type: ignore[arg-type]
        return node.op in ("placeholder", "output") or node.target in fallbacks

    if is_buffer(node):
        return True

    def realizes_inputs(node: torch.fx.Node) -> bool:
        return node.op == "output" or node.target in needs_realized_inputs

    if any(realizes_inputs(user) for user in node.users):
        return True

    # Otherwise, assume the node isn't realized
    return False
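

# For example (illustrative): a placeholder node is always realized (it is
# a graph input), as is any node with a user listed in
# needs_realized_inputs; by contrast, a pointwise op whose only user is
# another pointwise op is typically fused during lowering and would return
# False here.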