compile_utils.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. # mypy: ignore-errors
  2. from typing import Callable
  3. import torch
  4. import torch.fx as fx
  5. from torch.utils import _pytree as pytree
  6. from torch.utils._pytree import tree_flatten
  7. aten = torch.ops.aten
  8. def get_aten_target(node: fx.Node) -> Callable:
  9. if hasattr(node.target, "overloadpacket"):
  10. return node.target.overloadpacket
  11. return node.target
  12. rand_ops = [
  13. aten.dropout,
  14. aten._fused_dropout,
  15. aten._standard_gamma,
  16. aten.bernoulli,
  17. aten.multinomial,
  18. aten.native_dropout,
  19. aten.normal,
  20. aten.poisson,
  21. aten.binomial,
  22. aten.rrelu,
  23. aten.rand_like,
  24. aten.rand,
  25. aten.randint,
  26. aten.randn,
  27. aten.randperm,
  28. ]
# return a new copy of torch.fx.graph.Graph with CSE applied to the input graph
def fx_graph_cse(fx_g: torch.fx.graph.Graph):
    """Return a new ``fx.Graph`` that is a copy of ``fx_g`` with common
    subexpression elimination applied.

    Two ``call_function`` nodes are treated as duplicates when they have the
    same target and the same env-substituted args/kwargs (including the pytree
    specs of nested containers).  Placeholder, output, get_attr, and
    random-number-generating nodes are copied through unchanged and never
    eliminated.  The input graph is not modified.
    """
    new_graph = fx.Graph()
    env = {}  # map from node in the old graph to node in the new graph
    hash_env = {}  # map from hash to a node in the new graph
    token_map = {}  # map from hash to token
    for n in fx_g.nodes:
        # The placeholder, output, and get_attr nodes are copied to the new graph without change
        # do not CSE away random operations
        if (
            n.op == "placeholder"
            or n.op == "output"
            or n.op == "get_attr"
            or get_aten_target(n) in rand_ops
        ):
            new_node = new_graph.node_copy(n, lambda x: env[x])
            env[n] = new_node
        else:  # n.op == 'call_function', should never see n.op == 'call_module' or 'call_method'
            # substitute args and kwargs members to their mapping in env if exists
            # specs can be used to reconstruct nested list/dictionaries
            def substitute(arg_list):
                arg_list, spec = tree_flatten(arg_list)
                for i in range(len(arg_list)):
                    v = arg_list[i]
                    if isinstance(v, torch.fx.node.Node) and v in env:
                        arg_list[i] = env[v]
                    # NOTE(review): symbolic scalars are compared via their
                    # underlying .node — presumably so equal symbolic values
                    # unify under hashing; confirm against SymNode semantics.
                    if isinstance(v, (torch.SymBool, torch.SymInt, torch.SymFloat)):
                        arg_list[i] = v.node
                return tuple(arg_list), spec

            args, args_spec = substitute(n.args)
            kwargs, kwargs_spec = substitute(n.kwargs)
            # each token corresponds to a unique node
            # nodes with the same token can be substituted
            token = {
                "target": n.target,
                "args": args,
                "args_spec": args_spec,
                "kwargs": kwargs,
                "kwargs_spec": kwargs_spec,
            }
            # hash substituted args to a number, do not hash specs because specs are not hashable
            # We need to add type into hash to avoid situations like:
            # hash((primals_2, 1.0)) == hash((primals_2, 1))
            hash_arg = hash(
                (tuple((a, type(a)) for a in args), tuple((a, type(a)) for a in kwargs))
            )
            hash_val = (n.target, hash_arg)
            # check if a node has a substitute and can be eliminated
            hash_val_in_hash_env = hash_val in hash_env
            # full token equality guards against hash collisions: the node is
            # only eliminated when the stored token matches exactly
            if hash_val_in_hash_env and token_map[hash_val] == token:
                env[n] = hash_env[hash_val]
                continue
            new_node = new_graph.node_copy(n, lambda x: env[x])
            env[n] = new_node
            # on a collision (same hash, different token) the first-seen node
            # keeps ownership of the hash slot
            if not hash_val_in_hash_env:
                hash_env[hash_val] = new_node
                token_map[hash_val] = token
    return new_graph
  87. def strip_overloads(gm):
  88. """
  89. Modifies the target of graph nodes in :attr:`gm` to strip overloads.
  90. Args:
  91. gm(fx.GraphModule): The input Fx graph module to be modified
  92. """
  93. for node in gm.graph.nodes:
  94. if isinstance(node.target, torch._ops.OpOverload):
  95. node.target = node.target.overloadpacket
  96. gm.recompile()
  97. def get_placeholders(graph):
  98. return graph.find_nodes(op="placeholder")
  99. def get_outputs(graph):
  100. for node in graph.find_nodes(op="output"):
  101. return pytree.tree_leaves(node.args[0])
  102. raise AssertionError("No output node found")