graph.py

# mypy: allow-untyped-defs
import functools
import itertools
import logging
import operator
import os
import re
import sys
import time
from collections import defaultdict
from contextlib import contextmanager
from typing import (
    Any,
    Callable,
    DefaultDict,
    Dict,
    List,
    Optional,
    Set,
    Tuple,
    TYPE_CHECKING,
    Union,
)

import sympy

import torch
import torch._logging
import torch.fx
from torch._decomp import get_decompositions
from torch._dynamo.utils import defake, dynamo_timed
from torch._logging import LazyString, trace_structured
from torch._prims_common import make_channels_last_strides_for
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.sym_node import magic_methods, method_to_operator
from torch.fx.experimental.symbolic_shapes import (
    free_unbacked_symbols,
    has_free_symbols,
    resolve_unbacked_bindings,
    RuntimeAssert,
    ShapeEnv,
    SymTypes,
)
from torch.utils._mode_utils import no_dispatch

from . import config, ir
from .codegen.common import (
    DeviceOpOverrides,
    get_device_op_overrides,
    get_scheduling_for_device,
    get_wrapper_codegen_for_device,
    register_backend_for_device,
)
from .codegen.cpp_wrapper_cpu import CppWrapperCpu
from .codegen.cpp_wrapper_cuda import CppWrapperCuda
from .codegen.wrapper import WrapperCodeGen
from .exc import (
    CppWrapperCodeGenError,
    LoweringException,
    MissingOperatorWithDecomp,
    MissingOperatorWithoutDecomp,
)
from .ir import (
    Constant,
    FixedLayout,
    InputBuffer,
    Pointwise,
    Reduction,
    StorageBox,
    TensorBox,
    TorchBindObject,
)
from .lowering import (
    constrain_to_fx_strides,
    FALLBACK_ALLOW_LIST,
    fallback_handler,
    fallback_node_due_to_unsupported_type,
    layout_constraints,
    lowerings,
    make_fallback,
    needs_realized_inputs,
    unsupported_output_tensor,
)
from .sizevars import SizeVarAllocator
from .utils import (
    convert_shape_to_inductor,
    gather_origins,
    get_cloned_parameter_buffer_name,
    get_sympy_Expr_dtype,
    maybe_get_suppress_shape_guards_ctx,
    should_assume_input_aligned,
)
from .virtualized import NullHandler, V


if TYPE_CHECKING:
    from torch._higher_order_ops.effects import _EffectType

log = logging.getLogger(__name__)
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
output_code_log = torch._logging.getArtifactLogger(__name__, "output_code")
aten = torch.ops.aten

_post_grad_graph_counter = itertools.count()

if config.is_fbcode():
    from torch._inductor.fb.utils import log_module_code
else:

    def log_module_code(*args, **kwargs):
        pass


def supported_dtype_of_cpp_wrapper(dtype, cuda):
    supported_dtype = {
        torch.float32,
        torch.float64,
        torch.int64,
        torch.int32,
        torch.int16,
        torch.int8,
        torch.uint8,
        torch.bool,
        torch.bfloat16,
        torch.complex32,
        torch.complex64,
        torch.complex128,
        torch.float16,
    }
    if cuda:
        supported_dtype.add(torch.float8_e4m3fn)
        supported_dtype.add(torch.float8_e5m2)
        supported_dtype.add(torch.float8_e4m3fnuz)
        supported_dtype.add(torch.float8_e5m2fnuz)

    return dtype in supported_dtype
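
# Example (illustrative): supported_dtype_of_cpp_wrapper(torch.float16, cuda=False)
# returns True, while supported_dtype_of_cpp_wrapper(torch.float8_e5m2, cuda=False)
# returns False, since the float8 variants are only added to the set when cuda=True.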


def may_get_constant_buffer_dtype(constant_buffer):
    assert isinstance(
        constant_buffer, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer)
    ), "get_constant_buffer_dtype only supports input of sympy.Symbol, sympy.Expr or sympy.core.numbers.Integer"
    if isinstance(constant_buffer, sympy.core.numbers.Integer):
        return torch.int64

    if isinstance(constant_buffer, sympy.Expr):
        return get_sympy_Expr_dtype(constant_buffer)

    if constant_buffer.is_integer:
        return torch.int64
    elif constant_buffer.is_float:
        return torch.float32
    else:
        return None


def is_magic_method(op):
    magic_ops = {method_to_operator(m) for m in magic_methods}
    return op in magic_ops


def getattr_recursive(obj, target):
    target_atoms = target.split(".")
    attr_itr = obj
    for i, atom in enumerate(target_atoms):
        if not hasattr(attr_itr, atom):
            raise RuntimeError(
                f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
            )
        attr_itr = getattr(attr_itr, atom)
    return attr_itr
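
# Example (illustrative): for a module whose attribute chain is `m.block.conv.weight`,
# getattr_recursive(m, "block.conv.weight") walks the dotted path one attribute at a
# time and returns the final weight, raising RuntimeError if any segment is missing.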


def mark_nodes_dislike_padding(g):
    """
    Nodes like convolution/convolution_backward want their inputs to be dense.
    Padding their inputs results in extra calls to copy kernels! On the other hand,
    padding usually helps reduction.

    The pass finds nodes that dislike padding. These are nodes that can be reached
    from a convolution/convolution_backward in the backward direction without
    going through a reduction.
    """
    if not config.comprehensive_padding:
        return
    ops_dislike_padding = {
        aten.convolution,
        aten.convolution_backward,
    }
    # what's a better way to collect the reduction ops?
    ops_like_padding = {
        aten.var_mean,
        aten.sum,
        aten.mean,
        aten.prod,
        aten.any,
        aten.amin,
        aten.amax,
        aten.min,
        aten.max,
        aten.argmin,
        aten.argmax,
        aten.scatter_reduce,
    }

    def _get_overload_packet(node):
        return (
            node.target._overloadpacket
            if node.op == "call_function" and hasattr(node.target, "_overloadpacket")
            else None
        )

    for cur in reversed(g.nodes):
        op = _get_overload_packet(cur)
        if not op:
            continue
        if op in ops_dislike_padding:
            cur.meta["dislike_padding"] = True

        if cur.meta.get("dislike_padding", False):
            # propagate
            for prior in cur.all_input_nodes:
                prior_op = _get_overload_packet(prior)
                if not prior_op:
                    continue
                if prior_op not in ops_like_padding:
                    prior.meta["dislike_padding"] = True
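
# Example (illustrative): in a chain `add -> relu -> conv`, the conv node is marked
# dislike_padding and the flag propagates backward through relu and add; if a
# reduction such as `sum` fed the conv instead, propagation would stop there since
# reductions generally benefit from padding.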


class GraphLowering(torch.fx.Interpreter):
    graph_outputs: List[ir.IRNode]

    def symbolic_sizes_strides(self, ex: torch.Tensor):
        """
        Support dynamic shapes and dynamic strides by assigning variables
        to each dimension. We duck-shape tensors, so if two tensors
        have the same size they get assigned the same symbolic variable.
        """
        if self.reuse_shape_env:
            return convert_shape_to_inductor(ex.size()), convert_shape_to_inductor(
                ex.stride()
            )
        else:
            from torch._dynamo.source import ConstantSource

            # TODO: this should not be needed once #93059 lands
            # https://github.com/pytorch/pytorch/pull/94031#discussion_r1096044816
            # TODO: make a dedicated UnknownSource for this?
            # NB: This is using the legacy default behavior from
            # create_symbolic_sizes_strides_storage_offset but we hope we can
            # just delete this entirely
            source = ConstantSource(
                f"__inductor_unknown_tensor_{len(self._shape_env.var_to_val)}"
            )
            (
                size,
                stride,
                _,
            ) = self._shape_env.create_symbolic_sizes_strides_storage_offset(
                ex,
                source,
            )
        size = [i.node.expr if isinstance(i, torch.SymInt) else i for i in size]
        stride = [i.node.expr if isinstance(i, torch.SymInt) else i for i in stride]
        return size, stride

    def static_sizes_strides(self, ex: torch.Tensor):
        """
        Primarily used for weights.
        """
        size = [sympy.Integer(i) for i in ex.size()]
        stride = [sympy.Integer(i) for i in ex.stride()]
        return size, stride
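
    # Example (illustrative): for a contiguous weight of shape (64, 3, 7, 7),
    # static_sizes_strides returns sympy.Integer sizes [64, 3, 7, 7] and strides
    # [147, 49, 7, 1], i.e. fully static shapes with no symbolic variables.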

    def init_backend_registration(self):
        if get_scheduling_for_device("cpu") is None:
            from .codegen.cpp import CppScheduling

            register_backend_for_device(
                "cpu", CppScheduling, WrapperCodeGen, CppWrapperCpu
            )

        if get_scheduling_for_device("cuda") is None:
            from .codegen.cuda_combined_scheduling import CUDACombinedScheduling

            # CUDACombinedScheduling combines Triton and CUDA C++ scheduling for CUDA devices via delegation
            register_backend_for_device(
                "cuda", CUDACombinedScheduling, WrapperCodeGen, CppWrapperCuda
            )

        if get_scheduling_for_device("xpu") is None:
            from .codegen.triton import TritonScheduling

            register_backend_for_device("xpu", TritonScheduling, WrapperCodeGen)
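
    # Sketch (assumption, not part of the upstream registrations above): an
    # out-of-tree backend for a hypothetical "my_device" could reuse the same hook,
    # e.g. register_backend_for_device("my_device", MyScheduling, WrapperCodeGen),
    # so that get_scheduling_for_device("my_device") resolves during lowering.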

    def __init__(
        self,
        gm: torch.fx.GraphModule,
        example_inputs: Optional[List[torch.Tensor]] = None,
        shape_env=None,
        graph_id=None,
        cpp_wrapper=False,
        aot_mode=False,
        user_visible_outputs=None,
        layout_opt=None,
        extern_node_serializer=None,
        is_inference=False,
        is_const_graph=False,
        const_output_index=None,
        const_code=None,
        const_module=None,
        name=None,
    ):
        super().__init__(gm)
        self.example_inputs = example_inputs
        self.layout_opt = (
            layout_opt
            if layout_opt is not None
            else self.decide_layout_opt(gm, is_inference=is_inference)
        )
        self.num_channels_last_conv = 0
        self.is_inference = is_inference
        self.is_const_graph = is_const_graph
        self.const_code = const_code
        self.const_module = const_module

        self.extra_traceback = False  # we do our own error wrapping
        if shape_env is None:
            shape_env = ShapeEnv()
            self.reuse_shape_env = False
        else:
            self._shape_env = shape_env
            self.reuse_shape_env = True
        self._shape_env = shape_env
        # We are going to start code generating runtime asserts, so make sure
        # you don't start adding new ones in the lowering process
        shape_env.freeze_runtime_asserts()
        # We're going to mutate ras_by_symbol as we finish generating them
        self.ras_by_symbol: Dict[
            sympy.Symbol, List[RuntimeAssert]
        ] = shape_env.deferred_runtime_asserts.copy()
        self.bound_unbacked_symbols: Set[sympy.Symbol] = set()
        self.sizevars = SizeVarAllocator(shape_env)
        self.graph_input_names: List[str] = []
        self.graph_inputs: Dict[str, TensorBox] = {}
        self.graph_inputs_original: Dict[str, InputBuffer] = {}
        self.device_types: Set[str] = (
            const_module.device_types if const_module else set()
        )
        self.device_idxs: Set[int] = const_module.device_idxs if const_module else set()
        self.cuda = False
        self.buffers: List[ir.Buffer] = []
        self.const_output_index: Dict[str, int] = (
            const_output_index if const_output_index else {}
        )
        self.folded_constants: Set[str] = (
            set(const_output_index.keys()) if const_output_index else set()
        )
        self.constants: Dict[str, torch.Tensor] = (
            const_module.constants if const_module else {}
        )
        self.torchbind_constants: Dict[str, torch._C.ScriptObject] = {}
        self.constant_reprs: Dict[str, str] = {}
        self.removed_buffers: Set[str] = set()
        self.removed_inplace_buffers: Set[str] = set()
        self.mutated_buffers: Set[str] = set()
        self.never_reuse_buffers: Set[str] = set()
        self.inplaced_to_remove: Set[str] = set()
        self.device_ops: DeviceOpOverrides = None  # type: ignore[assignment]
        self.wrapper_code: WrapperCodeGen = None  # type: ignore[assignment]
        # See `ProxyExecutor Design Note` in ir.py for more details
        self.extern_kernel_nodes: List[ir.ExternKernelNode] = []
        self.extern_node_serializer: Optional[
            Callable[[List[ir.ExternKernelNode]], Any]
        ] = extern_node_serializer
        self.current_node: torch.fx.Node = None  # type: ignore[assignment]
        self.lists: Dict[str, List[str]] = {}
        self.mutated_inputs: Set[str] = set()
        self.mutated_input_idxs: List[int] = []
        self.name_to_buffer: Dict[str, ir.Buffer] = {}
        self.name_to_users: DefaultDict[str, List[ir.IRNode]] = defaultdict(list)
        self.creation_time = time.time()
        self.name = name
        self.cpp_wrapper = cpp_wrapper

        # record multi_kernel choice for cpp_wrapper so the second pass knows
        # which sub-kernel is picked. Copy cpp_wrapper to another variable
        # since cpp_wrapper flag is set to false for the first pass of codegen.
        self.record_multi_kernel_choice = cpp_wrapper
        self.multi_kernel_to_choice: Dict[str, int] = {}

        self.aot_mode = aot_mode
        self.graph_id = graph_id
        self.post_grad_graph_id = next(_post_grad_graph_counter)
        self.scheduler: torch._inductor.scheduler.Scheduler = None  # type: ignore[assignment]
        self.nodes_prefer_channels_last = (
            self.find_nodes_prefer_channels_last() if self.layout_opt else set()
        )
        mark_nodes_dislike_padding(gm.graph)
        self._warned_fallback = {"aten.convolution_backward"}
        self.user_visible_outputs = (
            user_visible_outputs if user_visible_outputs is not None else {}
        )
        self.cache_key: str = ""  # This is the cache key for the compiled artifact
        self.cache_path: str = ""  # This is the path in the filesystem where the compiled artifact is stored
        self.cache_linemap: List[
            Tuple[int, str]
        ] = (
            []
        )  # This is the linemap used by the profiler to mark custom compiled kernels getting run
        # Used if lowering encounters cases where cudagraphs are not supported
        self.disable_cudagraphs_reason: Optional[str] = None

        # only keeping one node per device for stack trace purposes
        self.device_node_mapping: Dict[torch.device, torch.fx.Node] = {}
        self.orig_gm: torch.fx.GraphModule = gm.__copy__()
        self.dynamo_flat_name_to_original_fqn = self.module.meta.get(
            "dynamo_flat_name_to_original_fqn", {}
        )
        self.allocated_constant_name = (
            const_module.allocated_constant_name if const_module is not None else {}
        )
        self.init_backend_registration()

        self.effectful_ops: Dict[_EffectType, ir.Buffer] = {}
        self.aligned_inputs: Set[str] = set()

    @staticmethod
    def decide_layout_opt(gm, *, is_inference) -> bool:
        """
        Decide if we should enable layout optimization for this graph based on
        heuristics.
        """
        if not config.layout_optimization:
            return False

        if config.force_layout_optimization:
            return True

        conv_nodes = [
            n for n in gm.graph.nodes if n.target == torch.ops.aten.convolution.default
        ]
        nconv = len(conv_nodes)

        if nconv == 0:
            return False

        # For cpu backend and mkldnn enabled, we always use channels_last for better performance.
        if (
            torch.backends.mkldnn.enabled
            and torch.backends.mkldnn.is_available()
            and all(
                n.args[idx].meta["val"].device == torch.device("cpu")
                for n in conv_nodes
                for idx in [0, 1]
            )
        ):
            return True

        # Following models are skipped due to this:
        # jx_nest_base
        # volo_d1_224
        if len(list(gm.graph.nodes)) >= 300 * nconv:
            log.debug("Skipped layout opt because only a few conv")
            return False

        if any(
            has_free_symbols(n.args[idx].meta["val"])
            for n in conv_nodes
            for idx in [0, 1]
        ):
            log.debug(
                "See perf regression with dynamic shape. Follow up in https://github.com/pytorch/pytorch/issues/102670"
            )
            return False

        def is_grouped(n):
            return n.args[-1] > 1 and n.args[1].meta["val"].size(1) > 1

        def is_in_out_channel(n):
            return (
                n.args[1].meta["val"].size(0) * 2 <= n.args[1].meta["val"].size(1)
                and n.args[1].meta["val"].size(2) > 1
            )

        def is_small_channel(n):
            return (
                n.args[1].meta["val"].size(0) <= 64
                and n.args[1].meta["val"].size(1) <= 64
            )

        # only grouped convolutions benchmarked as slower in conv samples for inference only
        if is_inference:
            from torch.utils.flop_counter import FlopCounterMode

            flop_counts: Dict[str, float] = defaultdict(float)
            for node in conv_nodes:
                success, args, kwargs = torch._inductor.fx_utils.get_fake_args_kwargs(
                    node
                )

                if success:
                    with FlopCounterMode(display=False) as flop_counter_mode:
                        with V.fake_mode:
                            node.target(*args, **kwargs)

                    counted_flops = flop_counter_mode.get_total_flops()
                    if is_grouped(node):
                        node_type = "grouped"
                    elif is_small_channel(node):
                        node_type = "small"
                    elif is_in_out_channel(node):
                        node_type = "in_out"
                    else:
                        node_type = "default"

                    flop_counts[node_type] += counted_flops
                else:
                    log.debug("Conv inputs meta not found")

            # average benchmarked channels last speedup / slowdown, < 1 is speedup.
            # taken from the set of convolution inputs in benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/
            # To regenerate these numbers follow https://gist.github.com/eellison/55d7a6ed6f39829d68ac56f95f4df5bb
            GROUPED_MULTIPLIER = 1.358
            DEFAULT_MULTIPLIER = 0.823
            IN_OUT_MULTIPLIER = 0.725
            SMALL_MULTIPLIER = 0.783

            total_flops = sum(flop_counts.values())
            # TODO - get different values per hardware
            weighted_flops = (
                flop_counts["grouped"] * GROUPED_MULTIPLIER
                + flop_counts["small"] * SMALL_MULTIPLIER
                + flop_counts["in_out"] * IN_OUT_MULTIPLIER
                + flop_counts["default"] * DEFAULT_MULTIPLIER
            )
            do_layout_opt = weighted_flops <= total_flops
            if not do_layout_opt:
                log.debug(
                    "Skipped layout opt in inference because weighted flops indicate slowdown, default: %d, channels last: %d",
                    total_flops,
                    weighted_flops,
                )
            return do_layout_opt
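
        # Illustrative (hypothetical) numbers: with 100 GFLOP of "default" convs and
        # 50 GFLOP of grouped convs, weighted_flops = 100 * 0.823 + 50 * 1.358 = 150.2,
        # which exceeds total_flops = 150, so channels-last would be skipped here.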

        # Channels last layout can dramatically hurt grouped conv perf. E.g.
        # Conv with arguments like
        #   {"input_shape": [32, 224, 112, 112], "weight_shape": [224, 112, 3, 3],
        #    "stride": [2, 2], "padding": [1, 1], "groups": 2}
        # slows down 31x using channels last..
        # But a lot of timm models use depthwise separable convolution which will
        # result in grouped convolution with in-channel size == 1.
        # For those grouped convolution, channels last still helps a lot.
        # E.g.
        # Conv with arguments
        #   {"input_shape": [128, 58, 56, 56], "weight_shape": [58, 1, 3, 3],
        #    "stride": [2, 2], "padding": [1, 1], "groups": 58}
        # get 1.86x speedup with channels last layout.
        #
        # The following heuristics skip using channels-last if the model contains
        # grouped convolution with in-channels > 1.
        if any(map(is_grouped, conv_nodes)):
            log.debug(
                "Skip layout opt because found grouped convolution with >1 in_channels!"
            )
            return False

        # For some models that contain convolution with larger in-channel than out-channel, applying
        # channels last hurts performance.
        # Following models are skipped due to this:
        # - pytorch_unet
        # - phlippe_densenet (slightly worse)
        # - Background_Matting (1.22x -> 0.821x)
        # - pytorch_CycleGAN_and_pix2pix (1.597x -> 1.294x)
        if any(map(is_in_out_channel, conv_nodes)):
            log.debug(
                "Skip layout opt because some convolutions have smaller out_channel"
            )
            return False

        # Following models are skipped due to this:
        # - functorch_maml_omniglot
        if all(map(is_small_channel, conv_nodes)):
            log.debug("Skip layout opt because all convolution channels are too small")
            return False

        return True

    def qualify_name(self, name: str) -> str:
        """Prepend the given name with the graph name if any."""
        if self.name is not None:
            return f"{self.name}_{name}"
        return name
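
    # Example (illustrative): with self.name == "subgraph0", qualify_name("buf3")
    # returns "subgraph0_buf3"; with self.name left as None it returns "buf3" unchanged.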

    def make_subgraph(
        self,
        gm: torch.fx.GraphModule,
        example_inputs: List[torch.Tensor],
        subgraph_name: str,
    ) -> "GraphLowering":
        """
        Make a subgraph of the current graph with all inherited
        parts, except the graph module (`gm`) and `example_inputs`.
        The subgraphs are lowered separately, but intended to be
        inlined in the parent graph's codegen. Hence the need
        for maintaining the same `shape_env` and other properties.
        The subgraph name is qualified by the parent graph's name.
        """
        return GraphLowering(
            gm=gm,
            example_inputs=example_inputs,
            shape_env=self._shape_env,
            cpp_wrapper=self.cpp_wrapper,
            aot_mode=self.aot_mode,
            extern_node_serializer=self.extern_node_serializer,
            is_inference=self.is_inference,
            name=self.qualify_name(subgraph_name),
        )

    def find_nodes_prefer_channels_last(self):
        """
        The rules to decide if a node prefers channels last are simple:
        1. it is the input/output of a convolution, or
        2. one of its users prefers channels last.

        We have rule 1 because cudnn runs a faster convolution kernel for channels last inputs;
        Rule 2 is also important. It makes sure that indirect inputs to convolution also prefer
        channels last.

        Consider the scenario: conv -> batch-norm -> relu -> conv
        Without rule 2, the batch-norm output may use a contiguous layout. That will cause 2 extra copies:
        1. the output of batch-norm should be channels last initially since its input is a conv's output.
           Forcing the batch-norm's output to be contiguous results in the first copy.
        2. The second conv's input is initially contiguous. This layout is propagated from the batch-norm's output.
           We need to convert it to channels last layout, which results in the second copy.
        With rule 2, we make sure all the tensors in the chain use channels last layout, so both copies
        can be saved.
        """
        output_set = set()
        for n in reversed(self.module.graph.nodes):
            if n.target == torch.ops.aten.convolution.default:
                output_set.add(n)
                continue

            for user in n.users:
                if user in output_set:
                    output_set.add(n)
                    break

        # need a second pass to add downstream nodes of those channels last nodes to the set.
        # This pass is especially needed to avoid mixed-layout kernel inputs in the backward pass.
        #
        # Let's say a conv-batchnorm's output is passed to relu whose output is in turn returned
        # from the fwd graph. Without this second pass, we will force relu's output to be contiguous.
        # Then in the kernel in the backward pass, the contiguous output of relu may be mixed with other
        # channels last tensors and passed to a kernel.
        #
        # This pass improves yolov3 training speedup from 1.116x (worse than disabling layout optimization speedup 1.196x) to 1.457x.
        # It also improves dla102 training speedup from 1.240x (worse than disabling layout optimization speedup 1.523x) to 1.835x.
        # This also helps the following models:
        # - res2net101_26w_4s
        # - res2net50_14w_8s
        # - sebotnet33ts_256
        for n in self.module.graph.nodes:
            if n in output_set:
                output_set.update(n.users)

        return output_set

    def warn_fallback(self, name):
        if name not in self._warned_fallback:
            self._warned_fallback.add(name)
            perf_hint_log.info("Using FallbackKernel: %s", name)

    def add_device_info(self, device: torch.device):
        self.device_types.add(device.type)
        if device.index is not None:
            self.device_idxs.add(device.index)
        if V.graph.current_node and device not in self.device_node_mapping:
            self.device_node_mapping[device] = V.graph.current_node

    @property
    def fake_mode(self):
        return V.fake_mode

    def get_buffer(self, buffer_name: str):
        if buffer_name in self.name_to_buffer:
            return self.name_to_buffer[buffer_name]
        if buffer_name in self.graph_inputs:
            return self.graph_inputs[buffer_name]
        if buffer_name in self.constants:
            data = V.graph.constants[buffer_name]
            return ir.ConstantBuffer(
                buffer_name,
                ir.FixedLayout(
                    data.device, data.dtype, *V.graph.static_sizes_strides(data)
                ),
            )
        return None

    def get_dtype(self, buffer_name: str):
        if buffer_name in self.constants:
            return self.constants[buffer_name].dtype
        if buffer_name in self.name_to_buffer:
            return self.name_to_buffer[buffer_name].get_dtype()
        if buffer_name in self.graph_inputs:
            return self.graph_inputs[buffer_name].get_dtype()
        m = re.match(r"(as_strided|reinterpret_tensor)\(([a-zA-Z0-9_]+),", buffer_name)
        if m:
            # group(2) is the wrapped buffer name; group(1) is just the wrapper call
            return self.get_dtype(m.group(2))
        raise KeyError(f"could not find {buffer_name}")
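
    # Example (illustrative): get_dtype("reinterpret_tensor(buf2, (4, 4), (4, 1), 0)")
    # matches none of the direct lookups, so the regex branch extracts "buf2" and
    # recursively returns buf2's dtype.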

    def get_numel(self, buffer_name: str):
        from .ir import MultiOutputLayout

        if buffer_name in self.constants:
            return self.constants[buffer_name].numel()
        if buffer_name in self.name_to_buffer:
            buf = self.name_to_buffer[buffer_name]
            if isinstance(getattr(buf, "layout", None), MultiOutputLayout):
                return 1
            return buf.get_numel()
        if buffer_name in self.graph_inputs:
            return self.graph_inputs[buffer_name].get_numel()
        raise KeyError(f"could not find {buffer_name}")

    @dynamo_timed
    def run(self, *args):
        return super().run(*args)

    def register_buffer(self, buffer: ir.Buffer, *, set_name: bool = False):
        name = self.qualify_name(f"buf{len(self.buffers)}")
        self.buffers.append(buffer)
        self.name_to_buffer[name] = buffer
        # Skip empty CPU tensor so that CUDA graphs can succeed, see https://github.com/pytorch/pytorch/pull/114144
        if (
            not (isinstance(buffer, ir.ComputedBuffer) and buffer.is_zero_elements())
            and buffer.get_device() is not None
        ):
            self.add_device_info(buffer.get_device())

        if set_name:
            buffer.name = name
        return name

    def register_list(self, buffer_names: List[str]):
        name = self.qualify_name("list_" + "_".join(buffer_names))
        self.lists[name] = buffer_names
        return name
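
    # Example (illustrative): register_list(["buf0", "buf1"]) stores the list under the
    # (possibly graph-qualified) key "list_buf0_buf1" and returns that name.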

    def register_users_of(self, node_output):
        def register(value):
            if isinstance(value, (list, tuple)):
                for x in value:
                    register(x)
            if isinstance(value, ir.IRNode):
                if (
                    not hasattr(value, "data")
                    or not isinstance(value.data, ir.IRNode)
                    or not (
                        hasattr(value.data, "data")
                        and isinstance(value.data.data, ir.IRNode)
                    )
                ):
                    return

                for read_name in value.get_read_names():
                    self.name_to_users[read_name].append(value)

        register(node_output)

    def mark_buffer_mutated(self, name: str):
        """
        When a buffer is mutated we need to make sure all the reads to
        the old version are realized before the mutation happens.
        """
        assert isinstance(name, str)
        self.mutated_buffers.add(name)

        if name not in self.name_to_users:
            return

        for user in self.name_to_users[name]:
            user.realize()

    def get_original_value_of_constant(self, name: str):
        """
        In AOTI, module buffers may have been mutated during the tracing and compilation.
        Thus we need to read from previously stored original buffers, to make sure the
        generated model.so uses correct initial values.
        """
        assert name in self.allocated_constant_name and name in self.constants, (
            "Can not find the original value for " + name
        )
        orig_name = get_cloned_parameter_buffer_name(self.allocated_constant_name[name])
        return (
            self.module.meta[orig_name]
            if orig_name in self.module.meta
            else self.constants[name]
        )

    def allocate_non_dup_const_name(self, name, data):
        orig_name = name
        if not config.aot_inductor.use_runtime_constant_folding:
            for constant_name, value in self.constants.items():
                if (
                    not data.is_mkldnn
                    and data.size() == value.size()
                    and data.stride() == value.stride()
                    and data.dtype == value.dtype
                    and data.device == value.device
                    and data.untyped_storage().data_ptr()
                    == value.untyped_storage().data_ptr()
                    and data.storage_offset() == value.storage_offset()
                ):
                    return constant_name

        if name is None:
            name = f"constant{len(self.constants)}"
        if name[0].isdigit():
            name = f"constant_{name}"
        name = self.qualify_name(name)
        # We may generate a var name for each constant in the codegen.
        # Let's only keep sane characters.
        prefix = re.sub(r"[^a-zA-Z0-9_]", "_", name)
        name = prefix
        cnt = 0
        while name in self.constants:
            name = f"{prefix}_{cnt}"
            cnt += 1

        self.constants[name] = data
        self.constant_reprs[name] = (
            f"{data.device!r} {data.dtype!r} "
            f"{tuple(data.size())!r} {tuple(data.stride())!r} "
            f"{hash(data):x}"
        )
        self.allocated_constant_name[name] = orig_name
        return name
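
    # Example (illustrative): allocate_non_dup_const_name(None, weight) may return
    # "constant0"; a second tensor sharing the same storage, size, stride, dtype and
    # device is deduplicated and receives the existing name instead (unless AOTI
    # runtime constant folding is enabled, which skips deduplication).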

    def add_tensor_constant(self, data, name=None):
        new_name = self.allocate_non_dup_const_name(name, data)
        return TensorBox.create(
            ir.ConstantBuffer(
                new_name,
                FixedLayout(data.device, data.dtype, *self.static_sizes_strides(data)),
            )
        )

    def constant_name(self, name: str, device_override: Optional[torch.device]):
        """
        We AOT copy constants to the devices they are needed on.
        If device_override doesn't match the constant's device, then
        copy it and return a different name.
        """
        if self.constants[name].device == device_override or device_override is None:
            return name
        with torch.utils._python_dispatch._disable_current_modes():
            # caller might have set fake tensor mode which will create a fake tensor
            # when calling .to, so unset modes here
            return self.allocate_non_dup_const_name(
                f"{name}_{device_override.type}{device_override.index or 0}",
                self.constants[name].to(device_override),
            )

    def placeholder(self, target: str, args, kwargs):
        example = super().placeholder(target, args, kwargs)
        self.graph_input_names.append(target)
        if isinstance(example, SymTypes):
            expr = example.node.expr
            self.graph_inputs[target] = expr
            return expr
        elif isinstance(example, (int, bool, float)):
            expr = sympy.sympify(example)
            self.graph_inputs[target] = expr
            return expr
        if isinstance(example, BackwardState):
            # Ignored arg, must be unused
            # Alternately we could filter this out in AotAutograd
            return None
        assert isinstance(example, torch.Tensor), example
        # todo(chilli): We can remove the last check once we turn buffers into
        # static shape tensors. That's a hack to workaround Inductor believing
        # the buffer should be static but us passing in a fake tensor with
        # symbolic shapes.
        if not example._has_symbolic_sizes_strides:
            # the first N inputs are weights
            sizes, strides = self.static_sizes_strides(example)
        else:
            sizes, strides = self.symbolic_sizes_strides(example)
        # TODO(jansel): handle input aliasing
        target = self.qualify_name(target)
        tensor = TensorBox.create(
            InputBuffer(
                target,
                FixedLayout(example.device, example.dtype, sizes, strides),
            )
        )
        self.graph_inputs[target] = tensor
        self.graph_inputs_original[target] = tensor.data.data
        self.add_device_info(example.device)

        # Note: [Input Alignment handling in Inductor]
        # Alignment matters for generating efficient code. Some operations,
        # e.g. vectorized loads, can only be performed on aligned inputs.
        #
        # But if we codegen assuming aligned inputs and then get unaligned
        # inputs at runtime, then we are forced to clone - which is bad for
        # both perf and memory usage.
        #
        # One option would be to guard on storage_offset%ALIGNMENT, and then
        # codegen based on this. But storage_offset guards turned out to be
        # expensive and cause recompiles; Instead, we're generating code
        # based on the alignment of the example input without guarding.
        with maybe_get_suppress_shape_guards_ctx():
            if should_assume_input_aligned(example):
                self.aligned_inputs.add(target)
        return tensor

    def call_function(self, target, args, kwargs):
        if target is operator.getitem and isinstance(args[0], (list, tuple, dict)):
            return super().call_function(target, args, kwargs)

        if hasattr(target, "_inductor_lowering_function"):
            # passthrough lowerings from .pattern_matcher
            return target(*args, **kwargs)

        def get_custom_op_layout_constraints(target, args, kwargs):
            # Custom operations that require preserving stride order
            # which run through implicit fallback must constrain their
            # arguments' fx strides
            layout_constraint = None
            if torch._C.Tag.needs_fixed_stride_order in target.tags:
                # We have to set the current args because call_function will immediately
                # evaluate this lowering after creating the fallback, without evaluating
                # the layout constraint
                constrain_fn = functools.partial(
                    constrain_to_fx_strides, ignore_mutated_args_FIXME=True
                )
                args, kwargs = constrain_fn(self.current_node, *args, **kwargs)
                # Also register the layout constraint so when the fallback
                # is used again, we can constrain the args to the same layout
                layout_constraint = constrain_fn
            return layout_constraint, args, kwargs

        if target not in lowerings:
            assert isinstance(
                target, torch._ops.OpOverload
            ), f"{target} is not an OpOverload"
            base_name = target.name().split(".")[0]
            if base_name in FALLBACK_ALLOW_LIST:
                make_fallback(target)
            elif config.implicit_fallbacks:
                layout_constraint, args, kwargs = get_custom_op_layout_constraints(
                    target, args, kwargs
                )
                error = (
                    MissingOperatorWithDecomp
                    if get_decompositions([target])
                    else MissingOperatorWithoutDecomp
                )
                log.info(
                    "Creating implicit fallback for:\n%s",
                    error.operator_str(target, args, kwargs),
                )
                make_fallback(target, layout_constraint)
            elif get_decompositions([target]):
                # There isn't a good way to dynamically patch this in
                # since AOT Autograd already ran. The error message tells
                # the user how to fix it.
                raise MissingOperatorWithDecomp(target, args, kwargs)
            else:
                raise MissingOperatorWithoutDecomp(target, args, kwargs)

        try:
            log.debug(" via %s", lowerings[target])
            out = lowerings[target](*args, **kwargs)
            return out
        except Exception as e:
            raise LoweringException(e, target, args, kwargs).with_traceback(
                e.__traceback__
            ) from None

    @staticmethod
    def can_inline_constant(t: torch.Tensor) -> bool:
        """
        True if this is a small constant attr that will be inlined.
        """
        return len(t.shape) == 1 and t.shape[0] <= 8
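
    # Example (illustrative): a 1-D tensor with at most 8 elements (e.g. shape (4,))
    # is considered inlinable, while a 0-D scalar or any 2-D tensor is not.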

    def get_attr(self, target, args, kwargs):
        # this is a constant
        value = getattr_recursive(self.module, target)

        if isinstance(value, torch.fx.GraphModule):
            return ir.Subgraph(name=target, graph_module=value)

        if isinstance(value, torch._C.ScriptObject):
            self.torchbind_constants[target] = value
            self.constant_reprs[target] = ""
            return TorchBindObject(target, value)

        if (
            config.aot_inductor.use_runtime_constant_folding
            or config.always_keep_tensor_constants
            or unsupported_output_tensor(value)
        ):
            return self.add_tensor_constant(value, target)

        with no_dispatch():
            if value.shape == ():
                return Constant(value.item(), value.dtype, value.device)
            if self.can_inline_constant(value):
                # tensor lowering has constant inlining logic
                from .lowering import tensor

                return tensor(value.tolist(), dtype=value.dtype, device=value.device)
        return self.add_tensor_constant(value, target)

    def call_module(self, target, args, kwargs):
        raise AssertionError

    def call_method(self, target, args, kwargs):
        raise AssertionError

    def output(self, target, args, kwargs):
        result = super().output(target, args, kwargs)
        if not isinstance(result, (tuple, list)):
            # nested subgraphs can have singleton outputs
            result = (result,)
        assert isinstance(result, (tuple, list)), type(result)
        assert all(
            isinstance(
                x,
                (
                    TensorBox,
                    ir.Constant,
                    type(None),
                    ir.ConstantBuffer,
                    sympy.Expr,
                    sympy.logic.boolalg.Boolean,
                    int,
                    ir.EffectfulKernel,
                ),
            )
            for x in result
        ), result

        fx_node_args = V.graph.current_node.args[0]  # type: ignore[arg-type]
        if not isinstance(fx_node_args, (tuple, list)):
            # nested subgraphs can have singleton outputs
            fx_node_args = (fx_node_args,)
        result = [ir.ExternKernel.realize_input(x) for x in result]
        result_correct_strides = []

        assert len(fx_node_args) == len(result)
        for r, fx_node in zip(result, fx_node_args):
            if not isinstance(r, (ir.TensorBox, ir.BaseView)):
                result_correct_strides.append(r)
            else:
                # AOT Autograd tries to detect stride divergence of inductor from output metadata.
                # Here, we try to avoid spurious divergence by matching insignificant strides,
                # such as those of dimensions of size 0 or 1.
                result_correct_strides.append(
                    self.try_match_insignificant_strides(
                        r, fx_node.meta["val"].stride()
                    )
                )

        self.graph_outputs = result_correct_strides
        value: ir.IRNode
        for name, value in self.graph_inputs.items():
            assert isinstance(
                value, (TensorBox, sympy.Expr)
            ), f"Unsupported inductor graph input type: {type(value)}"
            if not isinstance(value, TensorBox):
                continue
            value.realize()
            assert isinstance(value, TensorBox)
            value = value.data
            assert isinstance(value, ir.StorageBox)
            value_storage_box = value
            value = value.data
            if not isinstance(value, InputBuffer) or value.get_name() != name:
                # one of our inputs was mutated, need to turn that into a copy
                ir.MutationLayoutSHOULDREMOVE.realize_into(
                    value, self.graph_inputs_original[name]
                )
                # replace output with mutated input
                try:
                    ind = self.graph_outputs.index(value_storage_box)
                    self.graph_outputs[ind] = self.graph_inputs_original[name]
                except ValueError:
                    pass

        self.finalize()
        log.debug(
            "Force channels last inputs for %d conv for the current graph with id %d",
            self.num_channels_last_conv,
            self.graph_id if self.graph_id is not None else -1,
        )

    def finalize(self):
        for buf in self.buffers:
            buf.decide_layout()

    @contextmanager
    def set_current_node(self, node: torch.fx.Node):
        old = self.current_node
        try:
            self.current_node = node
            yield
        finally:
            self.current_node = old

    def try_match_insignificant_strides(
        self,
        tensor,
        meta_strides_inp: Tuple[Union[int, torch.SymInt], ...],
    ) -> ir.TensorBox:
        """
        Tries to match the strides of the tensor to those in the meta_strides. Strides of insignificant
        dimensions - size 0 or 1 - will be updated.

        If there are real stride differences (NHWC vs NCHW) then the input will be returned.
        """
        # should have already been realized
        assert torch._inductor.ir.is_storage_and_layout(tensor)

        meta_strides = [
            s.node.expr if isinstance(s, torch.SymInt) else s for s in meta_strides_inp
        ]

        if all(
            self.sizevars.statically_known_equals(s1, s2)
            for s1, s2 in zip(meta_strides, tensor.get_stride())
        ):
            return tensor

        def significant_strides_equal(shape, meta_strides, tensor_strides):
            for dim, s1, s2 in zip(shape, meta_strides, tensor_strides):
                if self.sizevars.statically_known_leq(dim, 1):  # type: ignore[arg-type]
                    continue

                if not self.sizevars.statically_known_equals(s1, s2):
                    return False

            return True

        if not significant_strides_equal(
            tensor.get_size(), meta_strides, tensor.get_stride()
        ):
            return tensor

        storage, old_layout = torch._inductor.ir.as_storage_and_layout(tensor)
        new_stride = list(old_layout.stride)
        for i, s in enumerate(tensor.get_size()):
            if self.sizevars.statically_known_leq(s, 1):  # type: ignore[arg-type]
                new_stride[i] = meta_strides[i]

        new_layout = torch._inductor.ir.FixedLayout(
            old_layout.device,
            old_layout.dtype,
            old_layout.size,
            new_stride,
            old_layout.offset,
        )
        return ir.TensorBox(torch._inductor.ir.ReinterpretView(storage, new_layout))
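
    # Example (illustrative): for a tensor of size (1, 8) with computed strides (8, 1)
    # and eager meta strides (1, 1), only the size-1 leading dimension differs, so the
    # buffer is reinterpreted with strides (1, 1); a genuine NCHW vs. NHWC mismatch on
    # dimensions larger than 1 leaves the input unchanged.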

    def run_node(self, n: torch.fx.Node):
        def debug(msg):
            log.debug("lowering %s %s", LazyString(n.format_node), msg)

        buffer_watermark = len(self.buffers)

        origins = {n}
        if n.op == "call_function":
            args, kwargs = self.fetch_args_kwargs_from_env(n)
            origins |= gather_origins(args, kwargs)
        with ir.IRNode.current_origins(origins), self.set_current_node(
            n
        ), V.set_current_node(n):
            if (
                n.op == "call_function"
                and n.target is not operator.getitem
                and fallback_node_due_to_unsupported_type(n)
            ):
                debug("fallback_handler")
                result = fallback_handler(n.target, add_to_fallback_set=False)(
                    *args, **kwargs  # type: ignore[possibly-undefined]
                )
            elif n.op == "call_function" and n.target in layout_constraints:
                debug("layout_constraints")
                args, kwargs = layout_constraints[n.target](n, *args, **kwargs)  # type: ignore[index]
                result = self.call_function(n.target, args, kwargs)
            elif is_magic_method(n.target):
                # TODO: this is sus, it probably should be handled in the
                # lowerings themselves similarly to sym_size/sym-stride
                # https://github.com/pytorch/pytorch/issues/127789
                debug("is_magic_method")
                if isinstance(
                    n.meta["val"], (torch.SymInt, torch.SymFloat, torch.SymBool)
                ):
                    result = n.meta["val"].node.expr
                else:
                    result = super().run_node(n)
            else:
                debug("")
                result = super().run_node(n)

            # Require the same stride order for dense outputs:
            # 1. user-land view() will not throw because inductor
            #    outputs different strides than eager;
            #    long term the solution is to make view() always succeed
            #    with infallible strides.
            # 2. for as_strided ops, we need to make sure their input has the same
            #    size/stride as in the eager model to align with eager behavior.
            as_strided_ops = [
                torch.ops.aten.as_strided.default,
                torch.ops.aten.as_strided_.default,
                torch.ops.aten.as_strided_scatter.default,
                torch.ops.aten.resize.default,
                torch.ops.aten.resize_as.default,
            ]
            is_output = any(user.op == "output" for user in n.users)
            is_input_for_as_strided = any(
                user.target in as_strided_ops for user in n.users
            )

            if n.meta.get("inductor_realize_to_strides", False) and isinstance(
                result, TensorBox
            ):
                result.realize()
                strides = n.meta["val"].stride()
                sym_strides = torch._inductor.utils.any_is_symbolic(*strides)
                if (
                    not hasattr(result, "get_stride")
                    or result.get_stride() != strides
                    and not sym_strides
                ):
                    stride_order = ir.get_stride_order(strides)
                    result = ir.ExternKernel.require_stride_order(result, stride_order)
            if (
                is_output
                and isinstance(result, TensorBox)
                and isinstance(result.data, ir.BaseView)
            ):
                # Realize so that outputs are correctly aliased
                result.realize()

            if (is_output or is_input_for_as_strided) and isinstance(
                n.meta["val"], torch.Tensor
            ):
                strides = n.meta["val"].stride()
                dense = torch._prims_common.is_non_overlapping_and_dense(n.meta["val"])
                unbacked_symbols_in_strides = len(free_unbacked_symbols(strides)) > 0
                # requiring a stride order for a non-dense output wouldn't
                # recreate the same strides, and would fail with view, defer for now.
                if not unbacked_symbols_in_strides and dense and len(strides):
                    stride_order = ir.get_stride_order(strides)
                    if (
                        len(result.get_size()) == 4
                        and n in self.nodes_prefer_channels_last
                        and n.name not in self.user_visible_outputs
                        and not is_input_for_as_strided
                    ):
                        stride_order = ir.NHWC_STRIDE_ORDER

                    allow_padding = (
                        n.name not in self.user_visible_outputs
                        and not is_input_for_as_strided
                    )
                    result = ir.ExternKernel.require_stride_order(
                        result, stride_order, allow_padding=allow_padding
                    )
  1140. # Realize if (1) any user need inputs realized, or (2) there is
  1141. # already too many reads and rematerializing can be bad.
  1142. num_users = len(set(n.users))
  1143. if num_users > 1 and isinstance(result, TensorBox):
  1144. for user in n.users:
  1145. if user.target in needs_realized_inputs:
  1146. result.realize_hint()
  1147. # This inclusion is somewhat controversial (from
  1148. # discussion between Horace, Natalia, and Elias).
  1149. # Currently, it's not very clear why this is helpful.
  1150. # The general idea here is that even though a node may
  1151. # have FlexibleLayout, we still often *treat* it as if
  1152. # it was contiguous. This appears to sometimes result in
  1153. # suboptimal behavior.
  1154. #
  1155. # When we do a better job selecting layout, we should
  1156. # revisit this.
  1157. need_fixed_layout = [
  1158. torch.ops.aten.convolution_backward.default,
  1159. torch.ops.aten.mm.default,
  1160. torch.ops.aten._int_mm.default,
  1161. ]
  1162. need_fixed_channels_last_layout = []
  1163. if not self.layout_opt:
  1164. need_fixed_layout.append(torch.ops.aten.convolution.default)
  1165. if torch._C._has_mkldnn:
  1166. need_fixed_layout += [
  1167. torch.ops.mkldnn._linear_pointwise.default,
  1168. torch.ops.mkldnn._linear_pointwise.binary,
  1169. torch.ops.aten.mkldnn_rnn_layer.default,
  1170. torch.ops.onednn.qlinear_pointwise.default,
  1171. torch.ops.onednn.qlinear_pointwise.tensor,
  1172. torch.ops.onednn.qlinear_pointwise.binary,
  1173. torch.ops.onednn.qlinear_pointwise.binary_tensor,
  1174. ]
  1175. need_fixed_channels_last_layout += [
  1176. torch.ops.mkldnn._convolution_pointwise.default,
  1177. torch.ops.mkldnn._convolution_pointwise.binary,
  1178. torch.ops.mkldnn._convolution_pointwise_.binary,
  1179. torch.ops.mkldnn._convolution_transpose_pointwise.default,
  1180. torch.ops.onednn.qconv2d_pointwise.default,
  1181. torch.ops.onednn.qconv2d_pointwise.binary,
  1182. ]
  1183. if torch._C.has_mkl:
  1184. need_fixed_layout += [torch.ops.mkl._mkl_linear.default]
  1185. if user.target in need_fixed_layout:
  1186. result = ir.ExternKernel.require_stride_order(
  1187. result,
  1188. ir.get_stride_order(n.meta["val"].stride()),
  1189. allow_padding=True,
  1190. )
  1191. if (
  1192. user.target in need_fixed_channels_last_layout
  1193. and n is user.args[0]
  1194. ):
  1195. result = ir.ExternKernel.require_stride_order(
  1196. result,
  1197. ir.get_stride_order(
  1198. make_channels_last_strides_for(n.meta["val"].shape)
  1199. ),
  1200. )
  1201. if user.op == "output":
  1202. if isinstance(result.data.data, (Pointwise, Reduction)):
  1203. result.realize()
  1204. # TODO(jansel): introduce a store vs inline choice
  1205. result.mark_reuse(len(n.users))
  1206. # Realize if the IRNode already has accumulated lots of reads
  1207. if isinstance(result, TensorBox) and result.has_exceeded_max_reads():
  1208. # Prevent excessive accumulation in a computed buffer, when
  1209. # there are multiple branches each with small number of memory
  1210. # reads, but they converge to a user.
  1211. result.realize_hint()
  1212. # Realize if a Pointwise has too much stuff to be inlined.
  1213. # As this may cause RecursionError during Inductor's evaluation.
  1214. if isinstance(result, TensorBox) and isinstance(result.data, StorageBox):
  1215. curr = result.data.data
  1216. if isinstance(curr, Pointwise):
  1217. # Use inner fn as a rough proxy. Good enough.
  1218. if curr.has_large_inner_fn():
  1219. result.realize()
        # This is not complete, but it doesn't have to be: origin_node
        # tracking is best effort. The logic here critically relies on direct
        # TensorBox -> StorageBox denoting a non-view; we don't bother trying
        # to get views to work. Feel free to add any extra cases as needed.
        #
        # Note: we can't YOLO tree_map over this result, because if there are
        # buffers or a view involved, we might not be able to validly assign
        # the origin_node here.
        if isinstance(result, TensorBox) and isinstance(result.data, ir.StorageBox):
            if isinstance(result.data.data, ir.Loops):
                result.data.data.origin_node = n
            elif isinstance(result.data.data, ir.Buffer):
                result.data.data.origin_node = n
                if isinstance(result.data.data, ir.ComputedBuffer) and isinstance(
                    result.data.data.data, ir.Loops
                ):
                    result.data.data.data.origin_node = n
                # Not really multi-output, can straightforwardly recurse in
                elif (
                    isinstance(result.data.data, ir.MultiOutput)
                    and not result.data.data.indices
                ):
                    if isinstance(result.data.data.inputs[0], ir.Buffer):
                        result.data.data.inputs[0].origin_node = n

        self.register_users_of(result)
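
        # Collect any unbacked SymInts defined by buffers registered while
        # lowering this node; they drive the deferred runtime asserts emitted
        # below and the final consistency check against the FX node's bindings.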
        new_unbacked_defs = set()
        for i in range(buffer_watermark, len(self.buffers)):
            new_unbacked_defs |= self.buffers[i].get_unbacked_symbol_defs()

        def format_buffers():
            r = []
            for b in self.buffers[buffer_watermark:]:
                r.append(
                    f"unbacked_symbol_defs={b.get_unbacked_symbol_defs()} in:\n{b}\n"
                )
            return "***\n".join(r)
        if n.op != "placeholder":
            # Note [Backwards runtime asserts]
            # Backwards poses an interesting problem for deferred runtime
            # asserts. In the easy case, we may solely close over
            # data-dependent sized tensors, and there are no binding sites for
            # unbacked SymInts. In this case, we can just drop all the
            # runtime asserts on the floor: no non-placeholder bindings, no
            # problem.
            #
            # However, it is *possible* for a fresh runtime assert to show up
            # between forwards and backwards. Right now, the freezing process
            # that happens when we lower forwards means that we will freeze
            # runtime asserts, and then the moment the backwards lowering
            # process attempts to add a new deferred runtime assert, we will
            # fail. Let's say you remove that assert. Now when we get here,
            # we need to make sure we actually emit these asserts (because we
            # can't emit them in forwards, we already compiled it). So we
            # have to do something here. But we don't want to reemit ALL
            # deferred runtime asserts, we only want to emit the NEW ones.
            # That requires some sort of stratification in the ShapeEnv.
            # This is all doable, it just hasn't been done yet.
            shape_env = V.graph.sizevars.shape_env

            for i0 in new_unbacked_defs:
                ras = self.ras_by_symbol.pop(i0, [])
                # NB: size-like not needed, we won't retrace
                vr = shape_env.var_to_range[i0]
                if not shape_env._default_unspecified_value_range().issubset(vr):

                    def convert(s):
                        try:
                            return int(s)
                        except TypeError:
                            return None

                    if (lower := convert(vr.lower)) is not None:
                        self.register_buffer(
                            ir.AssertScalar(i0 >= vr.lower, f"{i0} >= {vr.lower}"),
                            set_name=True,
                        )
                    if (upper := convert(vr.upper)) is not None:
                        self.register_buffer(
                            ir.AssertScalar(i0 <= vr.upper, f"{i0} <= {vr.upper}"),
                            set_name=True,
                        )

                for ra in ras:
                    fvs = free_unbacked_symbols(ra.expr)
                    missing = fvs - self.bound_unbacked_symbols
                    if missing:
                        i1 = sorted(missing, key=lambda x: str(x))[0]
                        self.ras_by_symbol.setdefault(i1, []).append(ra)
                    else:
                        self.register_buffer(
                            ir.AssertScalar(ra.expr, f"{ra.expr}"), set_name=True
                        )

        self.bound_unbacked_symbols |= new_unbacked_defs

        unbacked_bindings = resolve_unbacked_bindings(
            V.graph.sizevars.shape_env, n.meta.get("unbacked_bindings", {})
        )
        # When we do lowering, it is possible we reallocate unbacked SymInts,
        # so we need to line up the unbacked SymInts when performing the test
        # here.
        #
        # In principle, we could permit lowering to introduce MORE unbacked
        # SymInts: as long as all the old unbacked ones are accounted for,
        # it's fine for Inductor to introduce extra calls to item()/unbacked()
        # or whatever. This actually happens in practice when an unbacked SymInt
        # gets memoized away; naively, when Inductor reprocesses a kernel, it
        # doesn't know that the memo still applies, and ends up allocating a
        # new symbol. However, this is generally a bad thing: we may still
        # end up needing to test equalities on the symbols, and a fresh
        # symbol is likely to hit lots of GuardOnDataDependent errors that
        # we already know facts for.
        renamed_unbacked_bindings = {
            V.fake_mode.shape_env.unbacked_renamings.get(s, s)
            for s in unbacked_bindings.keys()
        }
        assert new_unbacked_defs >= renamed_unbacked_bindings, (
            f"failed {new_unbacked_defs} >= {renamed_unbacked_bindings} (inductor >= fx)\n"
            f"fx node is: {n.format_node()}\n"
            f"new buffers are:\n\n{format_buffers()}"
        )

        return result
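
    # The C++ wrapper can only be generated when C++ codegen is enabled, the
    # platform is Linux or macOS, and every graph input has a dtype supported
    # by the C++ wrapper; otherwise CppWrapperCodeGenError is raised.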
    def validate_can_generate_cpp_wrapper(self):
        if config.disable_cpp_codegen:
            raise CppWrapperCodeGenError("C++ codegen is disabled")

        if sys.platform not in ["linux", "darwin"]:
            raise CppWrapperCodeGenError(f"Unsupported platform {sys.platform}")

        for value in self.graph_inputs.values():
            dtype = None
            if isinstance(value, TensorBox):
                dtype = value.get_dtype()
            elif isinstance(
                value, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer)
            ):
                dtype = may_get_constant_buffer_dtype(value)

            if not supported_dtype_of_cpp_wrapper(dtype, self.cuda):
                raise CppWrapperCodeGenError(f"Unsupported input dtype {dtype}")
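
    # Pick the wrapper codegen class for this graph's device: "cpu" and "meta"
    # are ignored, and at most one other device type may be present.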
    def init_wrapper_code(self):
        self.cuda = "cuda" in self.device_types
        if self.cpp_wrapper:
            self.validate_can_generate_cpp_wrapper()

        device_types = self.device_types.copy()
        device_types.discard("cpu")
        device_types.discard("meta")
        # TODO(Eikan): Only mixing cpu and one other device is supported for now.
        assert len(device_types) <= 1, "Does not support mixing {}".format(
            "+".join(device_types)
        )
        only_cpu = len(device_types) == 0
        device_type = "cpu" if only_cpu else device_types.pop()

        self.device_ops = get_device_op_overrides(device_type)
        wrapper_code_gen_cls = get_wrapper_codegen_for_device(
            device_type, self.cpp_wrapper
        )
        assert wrapper_code_gen_cls is not None, f"Device {device_type} not supported"
        self.wrapper_code = wrapper_code_gen_cls()

        if self.const_module:
            # If we have a const module, we can reuse its kernels.
            # This avoids duplication and saves recompilation time (for Triton).
            self.wrapper_code._names_iter = self.const_module.wrapper_code._names_iter
            self.wrapper_code.src_to_kernel = (
                self.const_module.wrapper_code.src_to_kernel
            )

    def codegen_with_cpp_wrapper(self):
        """
        For CPU, the cpp wrapper codegen is done in one pass.
        For GPU, the cpp wrapper codegen is done in two steps: JIT-compile the model with python
        wrapper code and run it to generate autotuned kernel binaries in the first pass; and then
        generate cpp wrapper code and compile it to a dynamic library in the second pass.
        """
        if "cuda" in self.device_types:
            # first pass
            self.cpp_wrapper = False
            # Although triton.store_cubin was set in compile_fx, the backward pass didn't pick
            # that up. In theory it should work by only setting triton.store_cubin to True here,
            # but that will cause a problem when use_runtime_constant_folding is set.
            with config.patch({"triton.store_cubin": True}):
                compiled = self.compile_to_module().call

            def materialize(x):
                if isinstance(x, (torch.SymInt, torch.SymFloat)):
                    # Need a concrete value to run dynamic shapes and tune the result
                    return x.node.hint
                elif isinstance(x, FakeTensor):
                    return defake(x)
                else:
                    assert isinstance(
                        x, torch.Tensor
                    ), "Unknown type when creating real inputs" + str(type(x))
                    return x

            tracing_context = torch._guards.TracingContext.try_get()
            if tracing_context is not None and not isinstance(
                V.real_inputs, NullHandler
            ):
                if tracing_context.output_strides:
                    tracing_context.output_strides.clear()

                params_flat = [
                    param
                    for param in tracing_context.params_flat  # type: ignore[union-attr]
                    if param is not None
                ]
                real_inputs = [
                    materialize(x) for x in itertools.chain(params_flat, V.real_inputs)
                ]
            else:
                # In the backward pass, V.real_inputs is not set.
                # Generating random inputs based on self.example_inputs sometimes can be problematic,
                # e.g. illegal memory access. A comprehensive fix is to autotune in a separate process.
                real_inputs = [
                    materialize(x)
                    for x in (
                        self.example_inputs
                        if isinstance(V.real_inputs, NullHandler)
                        else V.real_inputs
                    )
                ]

            if self.mutated_inputs:
                from .compile_fx import clone_preserve_strides

                mutated_input_idxs = [
                    idx
                    for idx, name in enumerate(self.graph_inputs)
                    if name in self.mutated_inputs
                    and isinstance(real_inputs[idx], torch.Tensor)
                ]
                for idx in mutated_input_idxs:
                    # Clone mutated Tensor inputs to avoid mutating them in
                    # the first pass of the CPP wrapper-based compilation, as
                    # this will lead to a side effect on the example inputs:
                    # e.g. if torch.compile(f)(x) is called on an input-mutating
                    # f, the inputs x will be mutated twice in the process:
                    # once here, and again when running the compiled model;
                    # this will also lead to a numerically incorrect output.
                    real_inputs[idx] = clone_preserve_strides(real_inputs[idx])

            with torch.utils._python_dispatch._disable_current_modes():
                compiled(real_inputs)
            del real_inputs

            # second pass
            # TODO: reuse self.scheduler from the first pass to speed up the second pass
            self.cpp_wrapper = True
            self.removed_buffers.clear()
            self.inplaced_to_remove.clear()
            V.graph.sizevars.precomputed_replacements.clear()
            V.graph.sizevars.inv_precomputed_replacements.clear()
            return self.codegen()
        else:
            # cpu
            return self.codegen()
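
    # Default codegen path: build a Scheduler over the lowered buffers, let it
    # emit kernels into the wrapper, then generate the wrapper source for this
    # graph.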
    def codegen(self):
        from .scheduler import Scheduler

        self.init_wrapper_code()

        self.scheduler = Scheduler(self.buffers)
        V.debug.draw_orig_fx_graph(self.orig_gm, self.scheduler.nodes)

        self.wrapper_code.push_codegened_graph(self)
        self.scheduler.codegen()
        result = self.wrapper_code.generate(self.is_inference)
        self.wrapper_code.pop_codegened_graph()
        return result

    def codegen_subgraph(self, parent_graph):
        """
        This is a more compact version of the `codegen()` above
        where we codegen this graph as a subgraph of some parent
        graph. The parent graph is passed as an argument: the
        intention is to inline code generation for the subgraph in
        the parent graph's wrapper code (including the generated
        kernels). The wrapper code is not finalized (via a `.generate()`
        call), as this will be done in the parent graph's `codegen()`.
        """
        from .scheduler import Scheduler

        self.wrapper_code = parent_graph.wrapper_code
        self.device_ops = parent_graph.device_ops
        self.cpp_wrapper = parent_graph.cpp_wrapper

        self.scheduler = Scheduler(self.buffers)
        self.scheduler.codegen()
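
    # Rough memory/runtime accounting over scheduler nodes: total bytes
    # read/written, per-node counts (num_bytes // 4), and estimated runtimes.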
    def count_bytes(self):
        total_bytes = 0
        node_counts = []
        node_runtimes = []
        for node in self.scheduler.nodes:
            num_bytes = node.get_read_write_buffers_sizes()
            total_bytes += num_bytes
            node_counts.append((node, num_bytes // 4))
            node_runtimes.append((node, node.get_estimated_runtime()))
        return total_bytes, node_counts, node_runtimes
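
    # Generate the output code, write it through PyCodeCache, and load it back
    # as a Python module; constants are attached as module attributes.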
    @dynamo_timed(phase_name="code_gen", fwd_only=False)
    def compile_to_module(self):
        from .codecache import PyCodeCache

        code, linemap = (
            self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
        )

        output_code_log.debug("Output code: \n%s", code)
        try:
            linemap = [(line_no, node.stack_trace) for line_no, node in linemap]
            key, path = PyCodeCache.write(code)
        except Exception:
            trace_structured(
                "inductor_output_code",
                # Just omit the filename, I still want the code though!
                payload_fn=lambda: code,
            )
            raise
        else:
            trace_structured(
                "inductor_output_code",
                lambda: {"filename": path},
                payload_fn=lambda: code,
            )

        mod = PyCodeCache.load_by_key_path(
            key,
            path,
            linemap=linemap,
            attrs={**self.constants, **self.torchbind_constants},
        )
        self.cache_key = key
        self.cache_path = path
        self.cache_linemap = linemap

        # Logged twice as per https://github.com/pytorch/pytorch/pull/99038#discussion_r1167826029
        # TODO: Revisit this once the logging API is more mature
        assert mod.__file__ is not None

        log_module_code(mod.__file__)
        log.debug("Output code written to: %s", mod.__file__)
        output_code_log.info("Output code written to: %s", mod.__file__)
        if config.benchmark_kernel:
            print(f"Compiled module path: {mod.__file__}", file=sys.stderr)
        V.debug.output_code(mod.__file__)
        V.debug.copy(os.path.splitext(mod.__file__)[0] + ".debug")
        return mod
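
    # In AOT mode, compile the C++ wrapper output via AotCodeCompiler and
    # return the path of the resulting artifact; otherwise return the loaded
    # module's `call` entry point.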
    def compile_to_fn(self):
        if self.aot_mode:
            from .codecache import AotCodeCompiler

            assert self.cpp_wrapper, "AOT mode only supports C++ wrapper"
            code, linemap = self.codegen_with_cpp_wrapper()
            output_code_log.debug("Output code: \n%s", code)

            serialized_extern_kernel_nodes = None
            if (
                config.is_fbcode()
                and self.extern_kernel_nodes
                and self.extern_node_serializer
            ):
                serialized_extern_kernel_nodes = self.extern_node_serializer(
                    self.extern_kernel_nodes
                )
                output_code_log.debug(
                    "Serialized Extern Kernel Nodes: \n%s",
                    serialized_extern_kernel_nodes,
                )

            # Directly return the file path with the compiled code
            return AotCodeCompiler.compile(
                self, code, serialized_extern_kernel_nodes, cuda=self.cuda
            )
        else:
            return self.compile_to_module().call
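
    # Skip graph outputs that do not correspond to real buffers
    # (NoneAsConstantBuffer, ShapeAsConstantBuffer).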
    def get_output_names(self):
        return [
            node.get_name()
            for node in self.graph_outputs
            if not isinstance(node, ir.NoneAsConstantBuffer)
            and not isinstance(node, ir.ShapeAsConstantBuffer)
        ]

    def is_unspec_arg(self, name: str):
        # Dynamo wraps unspec variables as 0-d CPU tensors; they need to be
        # converted to scalars during codegen (Triton only).
        return (
            name in self.graph_inputs.keys()
            and self.graph_inputs[name].get_numel() == 1
            and self.graph_inputs[name].get_device().type == "cpu"
        )