# mypy: allow-untyped-defs
import contextlib
import dataclasses
import functools
import itertools
import logging
import math
import re
import sys
from copy import copy, deepcopy
from enum import Enum
from typing import Any, cast, Dict, List, Optional, Sequence, Set, Tuple, Union

import sympy

import torch
import torch.fx
from torch._inductor import dependencies
from torch._prims_common import is_float_dtype
from torch.utils import _pytree as pytree
from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges

from ..._dynamo.utils import counters
from .. import codecache, config, ir, metrics
from ..codegen.wrapper import WrapperCodeGen
from ..optimize_indexing import range_expressable_in_32_bits
from ..scheduler import (
    BaseSchedulerNode,
    BaseScheduling,
    ForeachKernelSchedulerNode,
    FusedSchedulerNode,
    Scheduler,
    SchedulerNode,
)
from ..utils import (
    cache_on_self,
    get_bounds_index_expr,
    get_fused_kernel_name,
    is_welford_reduction,
    parallel_num_threads,
    Placeholder,
    sympy_index_symbol,
    sympy_index_symbol_with_prefix,
    sympy_product,
    sympy_subs,
)
from ..virtualized import NullKernelHandler, ops, OpsValue, V
from .common import (
    BracesBuffer,
    CppWrapperKernelArgs,
    CSE,
    CSEVariable,
    DataTypePropagation,
    DeferredLine,
    DTYPE_TO_COMPUTATION_DTYPE,
    IndentedBuffer,
    Kernel,
    KernelArgs,
    OpOverrides,
    OptimizationContext,
)
from .cpp_utils import cexpr, cexpr_index, DTYPE_TO_CPP, INDEX_TYPE, value_to_cpp

schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")

NATIVE_OMP_RTYPES = {"+", "*", "^", "||", "min", "max"}
RTYPE_TO_CPP = {
    "sum": "+",
    "prod": "*",
    "xor_sum": "^",
    "min": "min",
    "max": "max",
    "argmin": "argmin",
    "argmax": "argmax",
    "any": "||",
    "welford_reduce": "welford",
    "welford_combine": "welford",
}
VECTORIZABLE_RTYPES = {
    "max",
    "min",
    "sum",
    "prod",
    "xor_sum",
    "welford_reduce",
    "welford_combine",
}

PYTHON_TO_CPP = {
    "Tensor": "at::Tensor",
    "int": "long",
    "float": "double",
    "bool": "bool",
    "str": "std::string",
    "ScalarType": "c10::ScalarType",
    "MemoryFormat": "at::MemoryFormat",
    "Layout": "at::Layout",
    "Device": "at::Device",
    "number": "at::Scalar",
}

CONTAINER_PYTHON_TO_CPP = {
    "List": "std::vector",
    "Optional": "c10::optional",
}

DTYPE_LOWP_FP = [
    torch.bfloat16,
    torch.float16,
]

BIN_CMP_OPS = ["eq", "ne", "le", "ge", "lt", "gt"]


def reduction_init(reduction_type, dtype):
    if dtype in DTYPE_LOWP_FP:
        # Since load promotes all half-precision inputs to float, the initial
        # constant for reduction must be promoted as well
        dtype = torch.float32
    if reduction_type in ("xor_sum", "sum", "any"):
        return 0
    if reduction_type == "prod":
        return 1
    if reduction_type in {"max", "argmax"}:
        return (
            f"-std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()"
            if is_float_dtype(dtype)
            else f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::min()"
        )
    if reduction_type in {"min", "argmin"}:
        return (
            f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()"
            if is_float_dtype(dtype)
            else f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::max()"
        )
    if is_welford_reduction(reduction_type):
        return f"Welford<{DTYPE_TO_CPP[dtype]}>()"
    raise AssertionError(reduction_type)


def reduction_acc_type(reduction_type, dtype):
    assert reduction_type not in {"argmin", "argmax"}
    scalar_type = DTYPE_TO_CPP[DTYPE_TO_COMPUTATION_DTYPE[dtype]]
    if is_welford_reduction(reduction_type):
        return f"Welford<{scalar_type}>"
    return scalar_type


def reduction_combine(reduction_type, var, next_value):
    if reduction_type == "sum":
        return f"{var} + {next_value}"
    if reduction_type == "prod":
        return f"{var} * {next_value}"
    if reduction_type == "xor_sum":
        return f"{var} ^ {next_value}"
    if reduction_type == "any":
        return f"{var} || {next_value}"
    if reduction_type in ("min", "max"):
        return f"{reduction_type}_propagate_nan({var}, {next_value})"
    if reduction_type == "welford_reduce":
        return f"welford_combine({var}, {next_value})"
    if reduction_type == "welford_combine":
        if isinstance(next_value, tuple):
            mean, m2, weight = next_value
        else:
            mean, m2, weight = reduction_project(reduction_type, next_value)
        return f"welford_combine({var}, {{{mean}, {m2}, {weight}}})"
    raise AssertionError(reduction_type)


def reduction_project(reduction_type, acc):
    if is_welford_reduction(reduction_type):
        return f"{acc}.mean", f"{acc}.m2", f"{acc}.weight"
    elif reduction_type in {"argmin", "argmax"}:
        return f"{acc}.index"
    return acc
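
# Illustrative sketch (not part of the original file): the three helpers above
# emit C++ source fragments that the reduction codegen stitches together.
# For example, assuming a float32 accumulator variable named "tmp_acc":
#   reduction_init("sum", torch.float32)           -> 0
#   reduction_init("max", torch.float32)           -> "-std::numeric_limits<float>::infinity()"
#   reduction_combine("max", "tmp_acc", "tmp0")    -> "max_propagate_nan(tmp_acc, tmp0)"
#   reduction_project("welford_reduce", "tmp_acc") -> ("tmp_acc.mean", "tmp_acc.m2", "tmp_acc.weight")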


def is_to_lowp_dtype(expr):
    to_exprs = ["convert<half>", "convert<bfloat16>"]
    return any(to_expr in expr for to_expr in to_exprs)


def get_lowp_to_fp32_expr(lowp_var, kernel):
    if isinstance(kernel, CppVecKernel):
        return f"at::vec::convert<float>({lowp_var})"
    else:
        assert isinstance(kernel, CppKernel)
        return f"c10::convert<float>({lowp_var})"


index_value_name_counter = 1


def argmax_argmin_prefix(reduction_type, src_dtype, tmpvar):
    global index_value_name_counter
    num_threads = (
        "max_threads" if config.cpp.dynamic_threads else parallel_num_threads()
    )
    struct_name = f"IndexValue_{index_value_name_counter}"
    index_value_name_counter += 1

    # A small annoyance, due to it being a little cumbersome to just throw {} into strings
    prefix = [
        f"struct {struct_name} {{size_t index; {DTYPE_TO_CPP[src_dtype]} value;}};",
        f"{struct_name} {tmpvar}{{0, {reduction_init(reduction_type, src_dtype)}}};",
    ]
    local_init = [
        f"{struct_name} {tmpvar}_local{{0, {reduction_init(reduction_type, src_dtype)}}};",
    ]
    tmpvar_per_thd = f"{tmpvar}_arr[{num_threads}]"
    parallel_prefix = [
        f"{struct_name} {tmpvar_per_thd};",
    ]
    return prefix, parallel_prefix, local_init
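
# Illustrative sketch (not part of the original file): for an "argmax" reduction
# over float32 with tmpvar "tmp_acc", the generated prefix looks roughly like
#   struct IndexValue_1 {size_t index; float value;};
#   IndexValue_1 tmp_acc{0, -std::numeric_limits<float>::infinity()};
# together with a per-thread array "IndexValue_1 tmp_acc_arr[<num_threads>];"
# and a thread-local "IndexValue_1 tmp_acc_local{0, ...};" for the parallel path.
# The struct name suffix depends on the global counter above.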


@functools.lru_cache
def stride_at(index: sympy.Expr, var: sympy.Symbol):
    replacement = {var: var + 1}
    new_index = sympy_subs(index, replacement)  # type: ignore[arg-type]
    return sympy.simplify(new_index - index)
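
# Illustrative sketch (not part of the original file): stride_at substitutes
# var -> var + 1 and returns the simplified difference, i.e. the symbolic stride
# of `index` with respect to `var`. For example, stride_at(3 * x + 7, x) == 3,
# and stride_at(y, x) == 0 since `index` does not depend on `x`.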


@functools.lru_cache
def simplify_index_in_vec_range(index: sympy.Expr, var: sympy.Expr, vec_length: int):
    """
    Simplifies the index expression within the range of a vectorized loop.
    Given a vectorized loop variable `var` in the range of a loop with `vec_length`,
    this function transforms the `index` into an equivalent form. It handles
    simplifications for cases where `var` can be expressed as `vec_length * a + b`,
    where `b` ranges from 0 to `vec_length - 1`. The function reduces occurrences
    of `FloorDiv` and `ModularIndexing` in the `index` with best-effort optimizations.

    NOTE:
    The simplified index expression is intended for analysis purposes only, not
    for code generation. It replaces `FloorDiv` and `ModularIndexing` with free variables
    which are not dependent on the loop variable `var` in the vectorized range. Check
    https://github.com/pytorch/pytorch/pull/117221#discussion_r1449746217 for more details.

    Examples:
    1. If `var` is `x3` and `vec_length` is 16, and `x3 = 16*a + b`, then
       `FloorDiv(x3, div)` or `ModularIndexing(x3, div, mod)` becomes a free variable
       when `div` is divisible by 16.
    2. `ModularIndexing(x3, 1, mod)` can be simplified to `x3 + c` where `c` is a free
       variable when `mod` is divisible by 16.
    """
    div_freevar_id = 0
    mod_freevar_id = 0

    def visit_indexing_div(divisor):
        nonlocal div_freevar_id
        result = FloorDiv(var, divisor)
        if sympy.gcd(divisor, vec_length) == vec_length:
            result = sympy.Symbol(f"{var}_div_c{div_freevar_id}")
            div_freevar_id += 1
        return result

    def visit_modular_indexing(divisor, modulus):
        nonlocal mod_freevar_id
        result = ModularIndexing(var, divisor, modulus)
        if sympy.gcd(divisor, vec_length) == vec_length:
            result = sympy.Symbol(f"{var}_mod_c{mod_freevar_id}")
            mod_freevar_id += 1
        elif divisor == 1 and sympy.gcd(modulus, vec_length) == vec_length:
            result = var + sympy.Symbol(f"{var}_mod_c{mod_freevar_id}")
            mod_freevar_id += 1
        return result

    original_index = index

    div = sympy.Wild("divisor", integer=True)
    if index.has(FloorDiv):
        index = index.replace(FloorDiv(var, div), visit_indexing_div)

    mod = sympy.Wild("modulus", integer=True)
    if index.has(ModularIndexing):
        index = index.replace(ModularIndexing(var, div, mod), visit_modular_indexing)

    index = sympy.simplify(index)
    if index != original_index:
        return simplify_index_in_vec_range(index, var, vec_length)

    return index
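
# Illustrative sketch (not part of the original file), with var = x3 and
# vec_length = 16 as in the docstring examples above:
#   FloorDiv(x3, 32)           -> x3_div_c0 (free symbol, since 32 is a multiple of 16)
#   ModularIndexing(x3, 1, 16) -> x3 + x3_mod_c0
#   FloorDiv(x3, 8)            -> unchanged (8 is not a multiple of 16)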


@functools.lru_cache
def stride_at_vec_range(index: sympy.Expr, var: sympy.Symbol, vec_length: int):
    index_vec_simplified = simplify_index_in_vec_range(index, var, vec_length)
    return stride_at(index_vec_simplified, var)


class OuterLoopFusedSchedulerNode(FusedSchedulerNode):
    @classmethod
    def fuse(  # type: ignore[override]
        cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode, outer_loop_fusion_depth
    ):
        assert node1.scheduler is node2.scheduler
        assert all(
            type(node)
            in (
                OuterLoopFusedSchedulerNode,
                SchedulerNode,
                FusedSchedulerNode,
            )
            for node in (node1, node2)
        )

        if any(type(node) is OuterLoopFusedSchedulerNode for node in (node1, node2)):
            return cls(
                node1.scheduler,
                (
                    list(node1.get_outer_nodes())
                    if type(node1) is OuterLoopFusedSchedulerNode
                    else [
                        node1,
                    ]
                )
                + (
                    list(node2.get_outer_nodes())
                    if type(node2) is OuterLoopFusedSchedulerNode
                    else [
                        node2,
                    ]
                ),
                outer_loop_fusion_depth,
            )
        else:
            return cls(node1.scheduler, [node1, node2], outer_loop_fusion_depth)  # type: ignore[list-item]

    def __init__(
        self,
        scheduler: "Scheduler",
        outer_fused_nodes: List[Union[FusedSchedulerNode, SchedulerNode]],
        outer_loop_fusion_depth,
    ):
        self.outer_fused_nodes: List[
            Union[FusedSchedulerNode, SchedulerNode]
        ] = outer_fused_nodes
        self.outer_loop_fusion_depth = outer_loop_fusion_depth
        flatten_snodes = []
        for _node in self.outer_fused_nodes:
            assert isinstance(_node, (SchedulerNode, FusedSchedulerNode))
            flatten_snodes.extend(list(_node.get_nodes()))
        super().__init__(scheduler, flatten_snodes)  # type: ignore[arg-type]

    def get_outer_nodes(self):
        return self.outer_fused_nodes

    def check_outer_fusion_loop_level_attr(
        self, cpp_kernel_proxy_list, outer_loop_fusion_depth
    ):
        # This function ensures that the same tiling split is applied at each loop
        # level within the outer loop fusion depth.
        # In the fusion stage, we only examine nodes with the same vars and reduce.
        # However, for nodes with the same vars and reduce, the loops may still have
        # different tile splits. For example (test_expr_vec_non_contiguous in
        # test_cpu_repro.py): buf0 tiles along the 2nd loop level, buf1 along the 3rd.
        # If the check fails, we fall back to standard loop codegen.
        def _inner(
            left_loop_level: LoopLevel,
            right_loop_level: LoopLevel,
            loop_fusion_depth: int,
        ) -> bool:
            # Check if the loop level attrs match
            outer_loops_attr_compare_list = [
                "var",
                "size",
                "offset",
                "steps",
            ]
            if not (
                all(
                    getattr(left_loop_level, attr_compare)
                    == getattr(right_loop_level, attr_compare)
                    for attr_compare in outer_loops_attr_compare_list
                )
            ):
                return False

            assert loop_fusion_depth >= 1
            if (loop_fusion_depth := loop_fusion_depth - 1) > 0:
                # If the next loop level is expected to undergo outer loop fusion,
                # there should be no kernel present at the current loop level.
                assert (
                    left_loop_level.kernel is None and right_loop_level.kernel is None
                )
                # Check the next loop level attrs
                if any(
                    # Assume no main/tail loop split at any outer loop fusion depth
                    # given no clear performance benefit for this complex case
                    len(loop_level.inner) != 1
                    for loop_level in [left_loop_level, right_loop_level]
                ) or not _inner(
                    left_loop_level.inner[0],
                    right_loop_level.inner[0],
                    loop_fusion_depth,
                ):
                    return False

            return True

        for idx in range(len(cpp_kernel_proxy_list) - 1):
            left_loop_nest = cpp_kernel_proxy_list[idx].loop_nest
            right_loop_nest = cpp_kernel_proxy_list[idx + 1].loop_nest
            if any(
                # Assume no main/tail loop split at any outer loop fusion depth
                len(loop_nest.root) != 1
                for loop_nest in [left_loop_nest, right_loop_nest]
            ) or not _inner(
                left_loop_nest.root[0], right_loop_nest.root[0], outer_loop_fusion_depth
            ):
                return False

        return True

    def merge_outer_fusion_kernels(
        self,
        cpp_kernel_proxy_list,
    ):
        loop_nest_list: List[LoopNestWithSplit] = [
            kernel.loop_nest for kernel in cpp_kernel_proxy_list
        ]
        metrics.cpp_outer_loop_fused_inner_counts.append(len(loop_nest_list))

        kernel_group = cpp_kernel_proxy_list[0].kernel_group

        def _merge_outer_fusion_loop_levels(
            loop_level_nested_list: List[List["LoopLevel"]],
            outer_loop_fusion_depth,
        ):
            assert outer_loop_fusion_depth >= 1
            # Assume no main/tail loop split at any outer loop fusion depth
            assert all(
                len(loop_level_list) == 1 for loop_level_list in loop_level_nested_list
            )
            if (outer_loop_fusion_depth := outer_loop_fusion_depth - 1) >= 1:
                # Further merge the next loop level
                next_loop_level_nested_list = [
                    loop_level_list[0].inner
                    for loop_level_list in loop_level_nested_list
                ]
                _merge_outer_fusion_loop_levels(
                    next_loop_level_nested_list,
                    outer_loop_fusion_depth,
                )
            else:
                outer_loop_fused_kernel = OuterLoopFusedKernel(kernel_group)
                loop_level_of_first_kernel = loop_level_nested_list[0][0]
                for kernel_idx in range(len(loop_level_nested_list)):
                    outer_loop_fused_kernel.inner.append(
                        deepcopy(loop_level_nested_list[kernel_idx][0]),
                    )
                loop_level_of_first_kernel.inner = []
                loop_level_of_first_kernel.kernel = outer_loop_fused_kernel

        # Merge the List[LoopNestWithSplit] from cpp_kernel_proxy_list
        # into cpp_kernel_proxy_list[0].loop_nest
        _merge_outer_fusion_loop_levels(
            [_loop_nest.root for _loop_nest in loop_nest_list],  # type: ignore[misc]
            self.outer_loop_fusion_depth,
        )
        return cpp_kernel_proxy_list[0]


class RecordOptimizationContext:
    def __init__(self, func_name: str = ""):
        self.func_name = func_name
        self.current_node: Optional[torch.fx.Node] = None
        self.opt_ctx: Optional[OptimizationContext] = None

    def __enter__(self):
        assert V.interpreter
        assert V.interpreter.current_node

        self.current_node = V.interpreter.current_node
        assert self.current_node is not None
        if OptimizationContext.key in self.current_node.meta:
            self.opt_ctx = self.current_node.meta[OptimizationContext.key]
        else:
            self.opt_ctx = OptimizationContext()
        assert self.opt_ctx is not None
        self.opt_ctx.ops_name = self.func_name
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        assert self.current_node
        assert self.opt_ctx
        self.current_node.meta[OptimizationContext.key] = self.opt_ctx

    def get_opt_ctx(self):
        return self.opt_ctx

    def get_fx_node(self):
        assert self.current_node
        return self.current_node


def get_opt_ctx(node: torch.fx.Node) -> OptimizationContext:
    return node.meta.get(OptimizationContext.key, None)


def get_current_node_opt_ctx() -> OptimizationContext:
    assert V.interpreter.current_node
    return get_opt_ctx(V.interpreter.current_node)


class CppCSEVariable(CSEVariable):
    def __init__(self, name, bounds: ValueRanges[Any]):
        super().__init__(name, bounds)
        self.is_vec = False
        self.dtype: Optional[torch.dtype] = None
        self.dependent_itervars: Set[sympy.Symbol] = set()

    def __repr__(self):
        return (
            f"CppCSEVariable(name: {self.name}, bounds: {self.bounds}, is_vec: {self.is_vec}, dtype: {self.dtype}, "
            f"dependent_itervars: {self.dependent_itervars})"
        )

    def update_on_args(self, name, args, kwargs):
        if name == "load":
            # args[1] is index
            self._set_dependent_itervars(args[1])
        else:
            # propagate relevant itervars and is_vec from args
            self.dependent_itervars.update(
                *[
                    arg.dependent_itervars
                    for arg in args
                    if isinstance(arg, CppCSEVariable)
                ]
            )
            if name == "index_expr":
                self._set_dependent_itervars(args[0])
            if any(arg.is_vec for arg in args if isinstance(arg, CppCSEVariable)):
                self.is_vec = True
        # NOTE [dtype of CppCSEVariable]
        # Deciding dtype according to the current optimization context is not
        # always accurate since the dtypes are initialized during dtype propagation
        # at the beginning of the codegen. It is possible that some ops are invoked
        # during the codegen of the current op and take different dtypes from the
        # current op.
        # TODO(jgong5): A more accurate way of deciding the dtype of the variables is to
        # propagate the dtypes here inside `update_on_args`.
        if (
            hasattr(V.interpreter, "current_node")
            and get_current_node_opt_ctx() is not None
        ):
            self.dtype = get_current_node_opt_ctx().dtype

        if name in BIN_CMP_OPS:
            self.dtype = torch.bool

    def _set_dependent_itervars(self, index: sympy.Expr):
        """
        Set the relevant itervars for this variable based on the `index` expression.
        This includes the itervars directly used in the `index` as well as relevant itervars
        of other cse variables used in the `index`.
        """
        for s in index.free_symbols:
            if s in V.kernel.itervars:
                self.dependent_itervars.add(s)  # type: ignore[arg-type]
            elif s.name in V.kernel.cse.varname_map:  # type: ignore[attr-defined]
                self.dependent_itervars.update(
                    V.kernel.cse.varname_map[s.name].dependent_itervars  # type: ignore[attr-defined]
                )

    def depends_on(self, itervar: sympy.Symbol):
        return itervar in self.dependent_itervars
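
# Illustrative sketch (not part of the original file): dependent_itervars lets
# the tiling/vectorization analysis ask whether a CSE value varies along a given
# loop variable. For instance, assuming x0 and x1 are kernel itervars, a value
# produced by a load at index x0 + 8 * x1 depends on both, so depends_on(x0) and
# depends_on(x1) are True, while a loop-invariant constant depends on neither.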


class CppOverrides(OpOverrides):
    """Map element-wise ops to C++"""

    @staticmethod
    def add(a, b):
        return f"decltype({a})({a} + {b})"

    @staticmethod
    def sub(a, b):
        return f"decltype({a})({a} - {b})"

    @staticmethod
    def mul(a, b):
        return f"decltype({a})({a} * {b})"

    @staticmethod
    def to_dtype(x, dtype, src_dtype=None):
        assert dtype in DTYPE_TO_CPP, f"{dtype} missing from {__name__}.DTYPE_TO_CPP"
        return f"c10::convert<{DTYPE_TO_CPP[dtype]}>({x})"

    @staticmethod
    def to_dtype_bitcast(x, dtype, src_dtype):
        assert dtype in DTYPE_TO_CPP, f"{dtype} missing from {__name__}.DTYPE_TO_CPP"
        if src_dtype in (torch.float16, torch.bfloat16):
            # c10::bit_cast requires the source and target to have the same bitwidth.
            # Because the input tensor's dtype could be promoted, e.g. from float16 to
            # float, we have to cast the tensor to its original source dtype before
            # invoking bit_cast. We also need to convert the bit-casted tensor
            # back to float to make sure we keep using higher precision values
            # for the rest of the computation.
            cast_x = f"c10::convert<{DTYPE_TO_CPP[src_dtype]}>({x})"
            cast_x = f"c10::bit_cast<{DTYPE_TO_CPP[dtype]}>({cast_x})"
            return f"c10::convert<{DTYPE_TO_CPP[torch.float32]}>({cast_x})"
        else:
            return f"c10::bit_cast<{DTYPE_TO_CPP[dtype]}>({x})"

    @staticmethod
    def abs(x):
        return f"std::abs({x})"

    @staticmethod
    def sin(x):
        return f"std::sin({x})"

    @staticmethod
    def cos(x):
        return f"std::cos({x})"

    @staticmethod
    def neg(x):
        return f"decltype({x})(-{x})"

    @staticmethod
    def exp(x):
        # return f"Sleef_expf_u10({x})"
        return f"std::exp({x})"

    @staticmethod
    def exp2(x):
        return f"std::exp2({x})"

    @staticmethod
    def expm1(x):
        return f"std::expm1({x})"

    @staticmethod
    def erf(x):
        return f"std::erf({x})"

    @staticmethod
    def erfc(x):
        return f"std::erfc({x})"

    @staticmethod
    def erfinv(x):
        return f"calc_erfinv({x})"

    @staticmethod
    def sqrt(x):
        return f"std::sqrt({x})"

    @staticmethod
    def rsqrt(x):
        return f"1 / std::sqrt({x})"

    @staticmethod
    def log1p(x):
        bug = config.cpp.inject_log1p_bug_TESTING_ONLY
        if bug == "accuracy":
            return f"{x} + decltype({x})(1)"
        elif bug is None:
            return f"std::log1p({x})"
        else:
            raise AssertionError(
                f"unrecognized config cpp.inject_log1p_bug_TESTING_ONLY = {bug!r}"
            )

    @staticmethod
    def tan(x):
        return f"std::tan({x})"

    @staticmethod
    def tanh(x):
        return f"std::tanh({x})"

    @staticmethod
    def signbit(x):
        return f"std::signbit({x})"

    @staticmethod
    def pow(a, b):
        return f"std::pow({a}, {b})"

    @staticmethod
    def log(x):
        return f"std::log({x})"

    @staticmethod
    def round(x):
        return f"std::nearbyint({x})"

    @staticmethod
    def floor(x):
        return f"std::floor({x})"

    @staticmethod
    def floordiv(a, b):
        # a and b are integer types
        quot = f"{a} / {b}"
        rem = f"{a} % {b}"
        return f"(({a} < 0) != ({b} < 0) ? ({rem} != 0 ? {quot} - 1 : {quot}) : {quot})"

    @staticmethod
    def ceil(x):
        return f"std::ceil({x})"

    @staticmethod
    def trunc(x):
        return f"std::trunc({x})"

    @staticmethod
    def truncdiv(a, b):
        # a and b are integer types
        return f"{a} / {b}"

    @staticmethod
    def fmod(a, b):
        return f"std::fmod({a}, {b})"

    @staticmethod
    def isinf(x):
        return f"std::isinf({x})"

    @staticmethod
    def isnan(x):
        return f"std::isnan({x})"

    @staticmethod
    def lgamma(x):
        return f"std::lgamma({x})"

    @staticmethod
    def acos(x):
        return f"std::acos({x})"

    @staticmethod
    def acosh(x):
        return f"std::acosh({x})"

    @staticmethod
    def cosh(x):
        return f"std::cosh({x})"

    @staticmethod
    def sinh(x):
        return f"std::sinh({x})"

    @staticmethod
    def asin(x):
        return f"std::asin({x})"

    @staticmethod
    def asinh(x):
        return f"std::asinh({x})"

    @staticmethod
    def atan2(x, y):
        return f"std::atan2({x}, {y})"

    @staticmethod
    def atan(x):
        return f"std::atan({x})"

    @staticmethod
    def atanh(x):
        return f"std::atanh({x})"

    @staticmethod
    def copysign(x, y):
        return f"std::copysign({x}, {y})"

    @staticmethod
    def frexp(x):
        cache_keys = f"frexp({x})[0]", f"frexp({x})[1]"
        if all(cache_key in V.kernel.cse.cache for cache_key in cache_keys):
            return tuple(V.kernel.cse.cache[cache_key] for cache_key in cache_keys)

        code = BracesBuffer()
        exponent = V.kernel.cse.newvar()
        mantissa = V.kernel.cse.newvar()
        code.writeline(f"int32_t {exponent};")
        code.writeline(f"auto {mantissa} = std::frexp({x}, &{exponent});")
        V.kernel.compute.splice(code)
        cse_vars = (mantissa, exponent)
        for cache_key, cse_var in zip(cache_keys, cse_vars):
            V.kernel.cse.cache[cache_key] = cse_var
        return mantissa, exponent

    @staticmethod
    def hypot(x, y):
        return f"std::hypot({x}, {y})"

    @staticmethod
    def log10(x):
        return f"std::log10({x})"

    @staticmethod
    def log2(x):
        return f"std::log2({x})"

    @staticmethod
    def nextafter(x, y):
        return f"std::nextafter({x}, {y})"

    @staticmethod
    def relu(x):
        bug = config.cpp.inject_relu_bug_TESTING_ONLY
        if bug == "compile_error":
            return "compile error!"
        elif bug == "runtime_error":
            return f"{x}; throw 1"
        elif bug == "accuracy":
            return f"{x} + decltype({x})(1)"
        elif bug is None:
            return f"std::max({x}, decltype({x})(0))"
        else:
            raise AssertionError(
                f"unrecognized config cpp.inject_relu_bug_TESTING_ONLY = {bug!r}"
            )

    @staticmethod
    def minimum(a, b):
        return f"min_propagate_nan({a}, {b})"

    @staticmethod
    def maximum(a, b):
        return f"max_propagate_nan({a}, {b})"

    @staticmethod
    def where(a, b, c):
        return f"{a} ? {b} : {c}"

    @staticmethod
    def mod(a, b):
        return f"mod({a}, {b})"

    @staticmethod
    def constant(val, dtype):
        opt_ctx: OptimizationContext = get_current_node_opt_ctx()
        assert opt_ctx and opt_ctx.dtype is not None, opt_ctx
        dtype = opt_ctx.dtype
        if dtype in DTYPE_LOWP_FP:
            # Since load promotes all half-precision inputs to float, constants
            # must be promoted as well
            dtype = torch.float32
        return value_to_cpp(val, DTYPE_TO_CPP[dtype])

    @staticmethod
    def index_expr(expr, dtype):
        opt_ctx: OptimizationContext = get_current_node_opt_ctx()
        assert opt_ctx and opt_ctx.dtype is not None
        dtype = opt_ctx.dtype

        idx_str = cexpr(V.kernel.rename_indexing(expr))
        var = V.kernel.cse.generate(
            V.kernel.compute, idx_str, bounds=get_bounds_index_expr(expr)
        )
        return ops.to_dtype(var, dtype)

    @staticmethod
    def masked(mask, body, other):
        code = BracesBuffer()

        # Write masked operation into a lambda
        body_var = V.kernel.cse.newvar()
        code.writeline(f"auto {body_var} = [&]")
        with V.kernel.swap_buffers(code), code.indent():
            result = body()
            code.writeline(f"return {result};")
        code.writeline(";")
        V.kernel.compute.splice(code)

        # Use the lambda's return type as the type of other
        other_code = value_to_cpp(other, f"decltype({body_var}())")
        return f"{mask} ? {body_var}() : {other_code}"

    @staticmethod
    def logical_and(a, b):
        return f"{a} && {b}"

    @staticmethod
    def logical_not(a):
        return f"!{a}"

    @staticmethod
    def logical_or(a, b):
        return f"{a} || {b}"

    @staticmethod
    def logical_xor(a, b):
        return f"{a} != {b}"

    @staticmethod
    def bitwise_and(a, b):
        return f"decltype({a})({a} & {b})"

    @staticmethod
    def bitwise_not(a):
        return f"decltype({a})(~{a})"

    @staticmethod
    def bitwise_or(a, b):
        return f"decltype({a})({a} | {b})"

    @staticmethod
    def bitwise_xor(a, b):
        return f"decltype({a})({a} ^ {b})"

    @staticmethod
    def bitwise_left_shift(a, b):
        return f"decltype({a})({a} << {b})"

    @staticmethod
    def bitwise_right_shift(a, b):
        return f"decltype({a})({a} >> {b})"

    @staticmethod
    def rand(seed: sympy.Expr, offset: sympy.Expr):
        return f"normalized_rand_cpu({seed}, {offset})"

    @staticmethod
    def randn(seed: sympy.Expr, offset: sympy.Expr):
        return f"randn_cpu({seed}, {offset})"

    @staticmethod
    def randint64(seed: sympy.Expr, offset: sympy.Expr, low, high):
        return f"randint64_cpu({seed}, {offset}, {low}, {high})"

    @staticmethod
    def sigmoid(x):
        return f"decltype({x})(1) / (decltype({x})(1) + std::exp(-{x}))"

    @staticmethod
    def sign(x):
        code = BracesBuffer()
        scalar_zero = f"decltype({x})(0)"
        scalar_one = f"decltype({x})(1)"
        code.writeline("[&]()")
        with code.indent():
            code.writeline(f"auto left = {x} > 0 ? {scalar_one} : {scalar_zero};")
            code.writeline(f"auto right = {x} < 0 ? {scalar_one} : {scalar_zero};")
            code.writeline("return left - right;")
        code.writeline("()")
        return code


CppOverrides._initialize_pointwise_overrides("cpp")
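
# Illustrative sketch (not part of the original file): CppOverrides methods map
# each pointwise op to a C++ expression string over already-emitted CSE
# variables. For instance, with scalar inputs named "tmp0" and "tmp1":
#   CppOverrides.add("tmp0", "tmp1")     -> "decltype(tmp0)(tmp0 + tmp1)"
#   CppOverrides.minimum("tmp0", "tmp1") -> "min_propagate_nan(tmp0, tmp1)"
# The vectorized counterparts in CppVecOverrides below produce the equivalent
# at::vec expressions instead, e.g. add simply yields "tmp0 + tmp1" on vectors.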
  799. class CppVecOverrides(CppOverrides):
  800. """Map element-wise ops to aten vectorization C++"""
  801. def __new__(cls, *args, **kargs):
  802. self = super().__new__(cls)
  803. def wrap(func):
  804. # `CppVecKernel` generates both scalar ops and vector ops according to
  805. # whether the inputs are scalars or vectors while all ops in `CppVecOverrides`
  806. # (except for some ops explained below) assume the inputs are vectors. We wrap the ops in
  807. # `CppVecOverrides` to broadcast scalar inputs to vectors if needed or fallback to
  808. # `CppOverrides` when all inputs are scalars.
  809. #
  810. # Notes on ops handled separately in their own functions:
  811. # `ops.masked`:
  812. # needs recursive handling of masked body.
  813. # `ops.index_expr`:
  814. # needs to further analyze the dependency of the index expression on
  815. # the tiling itervar.
  816. def wrapper(*args, **kwargs):
  817. scalars = [
  818. arg
  819. for arg in args
  820. if isinstance(arg, (int, sympy.Expr))
  821. or (isinstance(arg, CppCSEVariable) and not arg.is_vec)
  822. ]
  823. vectors = [
  824. arg
  825. for arg in args
  826. if isinstance(arg, CppCSEVariable) and arg.is_vec
  827. ]
  828. new_args = list(args)
  829. if scalars and vectors:
  830. # broadcast scalar args to vector if needed
  831. new_args = []
  832. vec_dtype = vectors[0].dtype
  833. for arg in args:
  834. if isinstance(arg, (int, sympy.Expr)):
  835. arg_dtype = torch.int64
  836. opt_ctx: OptimizationContext = get_current_node_opt_ctx()
  837. assert opt_ctx
  838. if opt_ctx.dtype is not None:
  839. arg_dtype = opt_ctx.dtype
  840. if isinstance(arg, sympy.Expr) and not arg.is_number:
  841. arg = ops.index_expr(arg, arg_dtype)
  842. else:
  843. arg = ops.constant(arg, arg_dtype)
  844. arg = arg.value if isinstance(arg, OpsValue) else arg
  845. if isinstance(arg, CppCSEVariable) and not arg.is_vec:
  846. assert isinstance(V.kernel, CppVecKernel)
  847. # align scalar data type to the vector for binary ops
  848. if len(args) == 2 and arg.dtype != vec_dtype:
  849. arg = ops.to_dtype(arg, vec_dtype)
  850. arg = arg.value if isinstance(arg, OpsValue) else arg
  851. # See NOTE [dtype of CppCSEVariable]: we have to fix arg.dtype since
  852. # the dtype from optimization context could be wrong.
  853. assert isinstance(arg, CppCSEVariable)
  854. arg.dtype = vec_dtype
  855. new_arg = V.kernel.broadcast(arg)
  856. new_args.append(new_arg)
  857. else:
  858. new_args.append(arg)
  859. if vectors:
  860. return func(*new_args, **kwargs)
  861. else:
  862. # fallback to scalar ops
  863. scalar_ops = super(CppVecOverrides, self)
  864. scalar_func = getattr(
  865. scalar_ops, func.__name__, scalar_ops.__getattr__(func.__name__) # type: ignore[attr-defined]
  866. )
  867. assert scalar_func is not None
  868. return scalar_func(*args, **kwargs)
  869. return wrapper
  870. for name, method in vars(CppVecOverrides).items():
  871. if getattr(method, "__class__", None) == staticmethod and name not in [
  872. "masked",
  873. "index_expr",
  874. ]:
  875. setattr(self, name, wrap(method.__func__))
  876. return self
  877. @staticmethod
  878. def add(a, b):
  879. return f"{a} + {b}"
  880. @staticmethod
  881. def sub(a, b):
  882. return f"{a} - {b}"
  883. @staticmethod
  884. def mul(a, b):
  885. return f"{a} * {b}"
  886. @staticmethod
  887. def truediv(a, b):
  888. return f"{a} / {b}"
  889. @staticmethod
  890. def abs(x):
  891. return f"{x}.abs()"
  892. @staticmethod
  893. def sin(x):
  894. return f"{x}.sin()"
  895. @staticmethod
  896. def cos(x):
  897. return f"{x}.cos()"
  898. @staticmethod
  899. def exp(x):
  900. return f"{x}.exp()"
  901. @staticmethod
  902. def exp2(x):
  903. return f"{x}.exp2()"
  904. @staticmethod
  905. def expm1(x):
  906. # decompose for a better performance
  907. vec_one = f"decltype({x})(1)"
  908. return f"{x}.exp() - {vec_one}"
  909. @staticmethod
  910. def erf(x):
  911. return f"{x}.erf()"
  912. @staticmethod
  913. def erfc(x):
  914. return f"{x}.erfc()"
  915. @staticmethod
  916. def erfinv(x):
  917. return f"{x}.erfinv()"
  918. @staticmethod
  919. def sqrt(x):
  920. return f"{x}.sqrt()"
  921. @staticmethod
  922. def eq(x, y):
  923. assert isinstance(V.kernel, CppVecKernel)
  924. assert isinstance(x, CppCSEVariable)
  925. assert x.dtype is not None
  926. return f"{V.kernel._get_mask_type(x.dtype)}({x} == {y})"
  927. @staticmethod
  928. def ne(x, y):
  929. assert isinstance(V.kernel, CppVecKernel)
  930. assert isinstance(x, CppCSEVariable)
  931. assert x.dtype is not None
  932. return f"{V.kernel._get_mask_type(x.dtype)}({x} != {y})"
  933. @staticmethod
  934. def lt(x, y):
  935. assert isinstance(V.kernel, CppVecKernel)
  936. assert isinstance(x, CppCSEVariable)
  937. assert x.dtype is not None
  938. return f"{V.kernel._get_mask_type(x.dtype)}({x} < {y})"
  939. @staticmethod
  940. def gt(x, y):
  941. assert isinstance(V.kernel, CppVecKernel)
  942. assert isinstance(x, CppCSEVariable)
  943. assert x.dtype is not None
  944. return f"{V.kernel._get_mask_type(x.dtype)}({x} > {y})"
  945. @staticmethod
  946. def le(x, y):
  947. assert isinstance(V.kernel, CppVecKernel)
  948. assert isinstance(x, CppCSEVariable)
  949. assert x.dtype is not None
  950. return f"{V.kernel._get_mask_type(x.dtype)}({x} <= {y})"
  951. @staticmethod
  952. def ge(x, y):
  953. assert isinstance(V.kernel, CppVecKernel)
  954. assert isinstance(x, CppCSEVariable)
  955. assert x.dtype is not None
  956. return f"{V.kernel._get_mask_type(x.dtype)}({x} >= {y})"
  957. @staticmethod
  958. def and_(x, y):
  959. return f"{x} & {y}"
  960. @staticmethod
  961. def rsqrt(x):
  962. return f"{x}.rsqrt()"
  963. @staticmethod
  964. def pow(a, b):
  965. return f"{a}.pow({b})"
  966. @staticmethod
  967. def log(x):
  968. return f"{x}.log()"
  969. @staticmethod
  970. def round(x):
  971. return f"{x}.round()"
  972. @staticmethod
  973. def floor(x):
  974. return f"{x}.floor()"
  975. @staticmethod
  976. def ceil(x):
  977. return f"{x}.ceil()"
  978. @staticmethod
  979. def trunc(x):
  980. return f"{x}.trunc()"
  981. @staticmethod
  982. def fmod(a, b):
  983. return f"{a}.fmod({b})"
  984. @staticmethod
  985. def lgamma(x):
  986. return f"{x}.lgamma()"
  987. @staticmethod
  988. def logical_and(a, b):
  989. return f"{a} & {b}"
  990. @staticmethod
  991. def logical_not(a):
  992. return f"~{a}"
  993. @staticmethod
  994. def logical_or(a, b):
  995. return f"{a} | {b}"
  996. @staticmethod
  997. def logical_xor(a, b):
  998. return f"{a} ^ {b}"
  999. @staticmethod
  1000. def tan(a):
  1001. return f"{a}.tan()"
  1002. @staticmethod
  1003. def tanh(a):
  1004. vec_one = f"decltype({a})(1)"
  1005. vec_two = f"decltype({a})(2)"
  1006. vec_minus_two = f"decltype({a})(-2)"
  1007. return f"{vec_two} / ({vec_one} + ({vec_minus_two} * {a}).exp()) - {vec_one}"
  1008. @staticmethod
  1009. def reciprocal(a):
  1010. return f"{a}.reciprocal()"
  1011. @staticmethod
  1012. def atan(x):
  1013. return f"{x}.atan()"
  1014. @staticmethod
  1015. def acos(x):
  1016. return f"{x}.acos()"
  1017. @staticmethod
  1018. def asin(x):
  1019. return f"{x}.asin()"
  1020. @staticmethod
  1021. def cosh(x):
  1022. return f"{x}.cosh()"
  1023. @staticmethod
  1024. def sinh(x):
  1025. return f"{x}.sinh()"
  1026. @staticmethod
  1027. def log10(x):
  1028. return f"{x}.log10()"
  1029. @staticmethod
  1030. def log2(x):
  1031. return f"{x}.log2()"
  1032. @staticmethod
  1033. def nextafter(x, y):
  1034. return f"{x}.nextafter({y})"
  1035. @staticmethod
  1036. def copysign(a, b):
  1037. return f"{a}.copysign({b})"
  1038. @staticmethod
  1039. def atan2(a, b):
  1040. return f"{a}.atan2({b})"
  1041. @staticmethod
  1042. def hypot(a, b):
  1043. return f"{a}.hypot({b})"
  1044. @staticmethod
  1045. def atanh(x):
  1046. # For real x, atanh(x) = 1/2 * log((1+x)/(1-x))
  1047. vec_one = f"decltype({x})(1)"
  1048. vec_one_half = f"decltype({x})(0.5)"
  1049. return f"{vec_one_half} * (({vec_one} + {x})/({vec_one} - {x})).log()"
  1050. @staticmethod
  1051. def asinh(x):
  1052. # For real x, asinh(x) = log(x + sqrt(1 + x**2))
  1053. vec_one = f"decltype({x})(1)"
  1054. return f"({x} + ({vec_one} + {x}*{x}).sqrt()).log()"
  1055. @staticmethod
  1056. def acosh(x):
  1057. return f"{x}.acosh()"
  1058. @staticmethod
  1059. def relu(x):
  1060. bug = config.cpp.inject_relu_bug_TESTING_ONLY
  1061. if bug == "compile_error":
  1062. return "compile error!"
  1063. elif bug == "runtime_error":
  1064. return f"{x}; throw 1"
  1065. elif bug == "accuracy":
  1066. return f"{x} + decltype({x})(1)"
  1067. elif bug is None:
  1068. return f"at::vec::clamp_min({x}, decltype({x})(0))"
  1069. else:
  1070. raise AssertionError(
  1071. f"unrecognized config cpp.inject_relu_bug_TESTING_ONLY = {bug!r}"
  1072. )
  1073. # TODO: this seems to be dead
  1074. @staticmethod
  1075. def sigmoid(x):
  1076. return f"decltype({x})(1)/(decltype({x})(1) + {x}.neg().exp())"
  1077. @staticmethod
  1078. def neg(x):
  1079. return f"{x}.neg()"
  1080. @staticmethod
  1081. def floordiv(a, b):
  1082. # a and b are integer type
  1083. _t = f"decltype({a})"
  1084. quot = f"{a} / {b}"
  1085. has_rem = f"({a} % {b} != {_t}(0))"
  1086. is_neg = f"(({a} < {_t}(0)) != ({b} < {_t}(0)))"
  1087. return f"{_t}::blendv({quot}, {quot} - {_t}(1), {has_rem} & {is_neg})"
  1088. @staticmethod
  1089. def truncdiv(a, b):
  1090. # a and b are integer type
  1091. return f"{a} / {b}"
  1092. @staticmethod
  1093. def minimum(a, b):
  1094. return f"at::vec::minimum({a}, {b})"
  1095. @staticmethod
  1096. def maximum(a, b):
  1097. return f"at::vec::maximum({a}, {b})"
  1098. @staticmethod
  1099. def square(a):
  1100. return f"{a} * {a}"
  1101. @staticmethod
  1102. def where(a, b, c):
  1103. assert isinstance(V.kernel, CppVecKernel)
  1104. if b.dtype == torch.bool:
  1105. assert c.dtype == torch.bool
  1106. blendv_a = f"{V.kernel._get_mask_cast(a, torch.float)}"
  1107. blendv_b = f"{V.kernel._get_mask_cast(b, torch.float)}"
  1108. blendv_c = f"{V.kernel._get_mask_cast(c, torch.float)}"
  1109. return f"decltype({b})::blendv({blendv_c}, {blendv_b}, {blendv_a})"
  1110. else:
  1111. return f"decltype({b})::blendv({c}, {b}, {V.kernel._get_mask_cast(a, b.dtype)})"
  1112. @staticmethod
  1113. def sign(x):
  1114. code = BracesBuffer()
  1115. vec_zero = f"decltype({x})(0)"
  1116. vec_one = f"decltype({x})(1)"
  1117. blendv_l = f"decltype({x})::blendv({vec_zero}, {vec_one}, {vec_zero} < {x})"
  1118. blendv_r = f"decltype({x})::blendv({vec_zero}, {vec_one}, {x} < {vec_zero})"
  1119. code.writeline("[&]()")
  1120. with code.indent():
  1121. code.writeline(f"auto left = {blendv_l};")
  1122. code.writeline(f"auto right = {blendv_r};")
  1123. code.writeline("return left - right;")
  1124. code.writeline("()")
  1125. return code
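# A scalar model of the vectorized sign() above (illustration only): `left` is 1
# where x is positive, `right` is 1 where x is negative, and their difference
# yields -1, 0, or +1.
def _example_sign_scalar(x: float) -> float:
    left = 1.0 if 0.0 < x else 0.0   # blendv(0, 1, 0 < x)
    right = 1.0 if x < 0.0 else 0.0  # blendv(0, 1, x < 0)
    return left - right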
  1126. @staticmethod
  1127. def to_dtype(x, dtype, src_dtype=None):
  1128. assert dtype in [
  1129. torch.bool,
  1130. torch.float,
  1131. torch.bfloat16,
  1132. torch.float16,
  1133. torch.uint8,
  1134. torch.int8,
  1135. torch.int32,
  1136. torch.int64,
  1137. ], f"{__name__} does not support {dtype}"
  1138. node: torch.fx.Node = V.interpreter.current_node
  1139. assert node and isinstance(node, torch.fx.Node)
  1140. opt_ctx_x = get_opt_ctx(node.args[1])
  1141. assert opt_ctx_x
  1142. assert opt_ctx_x.dtype is not None
  1143. assert isinstance(V.kernel, CppVecKernel)
  1144. src_dtype = opt_ctx_x.dtype
  1145. src_cpp_type = DTYPE_TO_CPP[src_dtype]
  1146. src_num_vectors = V.kernel._get_num_vectors(src_dtype)
  1147. dst_cpp_type = DTYPE_TO_CPP[dtype]
  1148. dst_num_vectors = V.kernel._get_num_vectors(dtype)
  1149. if src_dtype != torch.bool and dtype == torch.bool:
  1150. return f"{V.kernel._get_mask_type(src_dtype)}::from<{src_cpp_type},{src_num_vectors}>({x})"
  1151. if opt_ctx_x.dtype == torch.bool and dtype != torch.bool:
  1152. return f"{x}.to<{dst_cpp_type},{dst_num_vectors}>()"
  1153. if src_dtype != dtype:
  1154. if src_num_vectors == dst_num_vectors == 1:
  1155. return f"at::vec::convert<{dst_cpp_type}>({x})"
  1156. else:
  1157. return f"at::vec::convert<{dst_cpp_type},{dst_num_vectors},{src_cpp_type},{src_num_vectors}>({x})"
  1158. return f"({x})"
  1159. @staticmethod
  1160. def log1p(x):
  1161. bug = config.cpp.inject_log1p_bug_TESTING_ONLY
  1162. if bug == "accuracy":
  1163. return f"{x} + decltype({x})(1)"
  1164. elif bug is None:
  1165. return f"{x}.log1p()"
  1166. else:
  1167. raise AssertionError(
  1168. f"unrecognized config cpp.inject_log1p_bug_TESTING_ONLY = {bug!r}"
  1169. )
  1170. @staticmethod
  1171. def masked(mask, body, other):
  1172. assert isinstance(V.kernel, CppVecKernel)
  1173. code = BracesBuffer()
  1174. var = V.kernel.cse.newvar()
  1175. with V.kernel.masked(mask) as new_mask:
  1176. code.writeline(f"auto {var} = [&]")
  1177. with V.kernel.swap_buffers(code), code.indent():
  1178. result = body()
  1179. code.writeline(f"return {result};")
  1180. code.writeline(";")
  1181. V.kernel.compute.splice(code)
  1182. dtype = result.dtype
  1183. body_code = f"{var}()"
  1184. body_code_vec = (
  1185. body_code
  1186. if result.is_vec
  1187. else f"{V.kernel._get_vec_type(dtype)}({body_code})"
  1188. )
  1189. other_code = value_to_cpp(other, DTYPE_TO_CPP[dtype])
  1190. # loading bool as VecMask<float, N>
  1191. other_code_vec = (
  1192. f"{V.kernel._get_mask_type()}::from({other_code})"
  1193. if dtype == torch.bool
  1194. else f"{V.kernel._get_vec_type(dtype)}({other_code})"
  1195. )
  1196. assert isinstance(new_mask, CppCSEVariable), new_mask
  1197. if new_mask.is_vec:
  1198. code = BracesBuffer()
  1199. code.writeline("[&]")
  1200. with V.kernel.swap_buffers(code), code.indent():
  1201. code.writeline(f"if ({new_mask}.all_zero())")
  1202. with code.indent():
  1203. code.writeline(f"return {other_code_vec};")
  1204. code.writeline("else")
  1205. with code.indent():
  1206. # Create cse variable to reuse kernel.overrides.where
  1207. body_vec_var = V.kernel.cse.generate(
  1208. V.kernel.compute,
  1209. body_code_vec,
  1210. )
  1211. other_vec_var = V.kernel.cse.generate(
  1212. V.kernel.compute,
  1213. other_code_vec,
  1214. )
  1215. assert isinstance(body_vec_var, CppCSEVariable), body_vec_var
  1216. assert isinstance(other_vec_var, CppCSEVariable), other_vec_var
  1217. body_vec_var.dtype = dtype
  1218. other_vec_var.dtype = dtype
  1219. code.writeline(
  1220. f"return {V.kernel.overrides.where(new_mask, body_vec_var, other_vec_var)};"
  1221. )
  1222. code.writeline("()")
  1223. csevar = V.kernel.cse.generate(
  1224. V.kernel.compute,
  1225. code,
  1226. )
  1227. elif result.is_vec:
  1228. csevar = V.kernel.cse.generate(
  1229. V.kernel.compute, f"{mask} ? {body_code_vec} : {other_code_vec}"
  1230. )
  1231. else:
  1232. csevar = V.kernel.cse.generate(
  1233. V.kernel.compute, f"{mask} ? {body_code} : {other_code}"
  1234. )
  1235. # `result` is explicitly added to the args for correct propagation
  1236. # of relevant itervars and vectorization status.
  1237. csevar.update_on_args("masked", (mask, body, other, result), {})
  1238. return csevar
  1239. @staticmethod
  1240. def index_expr(expr, dtype):
  1241. opt_ctx: OptimizationContext = get_current_node_opt_ctx()
  1242. assert opt_ctx and opt_ctx.dtype is not None
  1243. dtype = opt_ctx.dtype
  1244. assert isinstance(V.kernel, CppVecKernel)
  1245. index = V.kernel.rename_indexing(expr)
  1246. tiling_var = V.kernel.itervars[V.kernel.tiling_idx]
  1247. stride = V.kernel._try_get_const_stride(index, tiling_var)
  1248. if stride == 0:
  1249. return CppOverrides.index_expr(expr, dtype)
  1250. elif stride is not None:
  1251. idx = V.kernel.cse.generate(
  1252. V.kernel.compute, cexpr(index), bounds=get_bounds_index_expr(expr)
  1253. )
  1254. value = ops.to_dtype(idx, dtype)
  1255. if isinstance(value, OpsValue):
  1256. value = value.value
  1257. csevar = V.kernel.arange(value, stride)
  1258. else:
  1259. csevar = V.kernel._load_or_store_non_contiguous( # type: ignore[assignment]
  1260. None, index, dtype, V.kernel.compute
  1261. )
  1262. csevar.update_on_args("index_expr", (expr, dtype), {})
  1263. return csevar
  1264. CppVecOverrides._initialize_pointwise_overrides("cppvec")
  1265. class CppTile2DOverrides(CppVecOverrides):
  1266. @staticmethod
  1267. def index_expr(expr, dtype):
  1268. assert isinstance(V.kernel, CppTile2DKernel)
  1269. expr = V.kernel.transform_indexing(expr)
  1270. return CppVecOverrides.index_expr(expr, dtype)
  1271. class CppKernel(Kernel):
  1272. overrides = CppOverrides # type: ignore[assignment]
  1273. sexpr = cexpr
  1274. newvar_prefix = "auto "
  1275. suffix = ";"
  1276. def __init__(self, args, num_threads):
  1277. super().__init__(args)
  1278. self.call_ranges: Optional[Tuple[sympy.Expr, ...]] = None
  1279. self.ranges: List[sympy.Expr] = []
  1280. self.itervars: List[sympy.Symbol] = []
  1281. self.reduction_depth = None
  1282. self.reduction_prefix = IndentedBuffer()
  1283. self.reduction_suffix = IndentedBuffer()
  1284. self.parallel_reduction_prefix = IndentedBuffer()
  1285. self.parallel_reduction_suffix = IndentedBuffer()
  1286. self.local_reduction_init = IndentedBuffer()
  1287. self.local_reduction_stores = IndentedBuffer()
  1288. self.is_reduction = False
  1289. self.non_parallel_reduction_prefix = IndentedBuffer()
  1290. self.reduction_cse = CSE(self.newvar_prefix, self.suffix, name_prefix="tmp_acc")
  1291. self.preloads = IndentedBuffer()
  1292. self.poststores = IndentedBuffer()
self.num_threads = num_threads  # the number of threads the kernel is specialized for
  1294. self.reduction_omp_dec: Dict[Tuple[str, str], str] = {}
  1295. def _gen_parallel_reduction_buffers(
  1296. self,
  1297. acc,
  1298. acc_type,
  1299. reduction_type,
  1300. dtype,
  1301. reduction_combine_fn=reduction_combine,
  1302. reduction_init_fn=reduction_init,
  1303. welford_weight_reciprocal_vec_fn=None,
  1304. ):
  1305. if config.cpp.dynamic_threads and not self.parallel_reduction_prefix:
  1306. self.parallel_reduction_prefix.writeline(
  1307. "int max_threads = omp_get_max_threads();"
  1308. )
  1309. acc_local = f"{acc}_local"
  1310. num_threads = (
  1311. "max_threads" if config.cpp.dynamic_threads else parallel_num_threads()
  1312. )
  1313. acc_per_thread = f"{acc}_arr[{num_threads}]"
  1314. acc_local_in_array = acc_per_thread.replace(f"[{num_threads}]", "[tid]")
  1315. self.local_reduction_init.writeline(
  1316. f"{acc_type} {acc_local} = {reduction_init_fn(reduction_type, dtype)};"
  1317. )
  1318. self.parallel_reduction_prefix.writeline(f"{acc_type} {acc_per_thread};")
  1319. self.parallel_reduction_prefix.writelines(
  1320. [
  1321. f"for (int tid = 0; tid < {num_threads}; tid++)",
  1322. "{",
  1323. f" {acc_local_in_array} = {reduction_init_fn(reduction_type, dtype)};",
  1324. "}",
  1325. ],
  1326. )
  1327. self.local_reduction_stores.writelines(
  1328. [
  1329. f"{acc_local_in_array} = {acc_local};",
  1330. ]
  1331. )
  1332. self.parallel_reduction_suffix.writelines(
  1333. [
  1334. f"for (int tid = 0; tid < {num_threads}; tid++)",
  1335. "{",
  1336. f" {acc} = {reduction_combine_fn(reduction_type, acc, acc_local_in_array)};",
  1337. "}",
  1338. ],
  1339. )
  1340. if (
  1341. reduction_type == "welford_reduce"
  1342. and welford_weight_reciprocal_vec_fn
  1343. and hasattr(self, "weight_recp_vec_range")
  1344. and "vec" in f"{acc_type}"
  1345. ):
  1346. self.local_reduction_init.writeline(
  1347. welford_weight_reciprocal_vec_fn(dtype, num_threads)
  1348. )
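# Rough shape of the per-thread reduction scaffolding generated above for a
# hypothetical float "sum" accumulator `tmp_acc0` with 8 threads; the /* ... */
# placeholders stand in for reduction_init/reduction_combine output
# (illustrative sketch, not the exact emitted text):
_EXAMPLE_PARALLEL_REDUCTION_SHAPE = r"""
float tmp_acc0_arr[8];                              // parallel_reduction_prefix
for (int tid = 0; tid < 8; tid++) { tmp_acc0_arr[tid] = /* init */; }
float tmp_acc0_local = /* init */;                  // local_reduction_init (per thread)
// ... worksharing body accumulates into tmp_acc0_local ...
tmp_acc0_arr[tid] = tmp_acc0_local;                 // local_reduction_stores
for (int tid = 0; tid < 8; tid++) {                 // parallel_reduction_suffix
    tmp_acc0 = /* combine(tmp_acc0, tmp_acc0_arr[tid]) */;
}
"""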
  1349. def get_reduction_var_pattern(self, line: str):
  1350. return re.search("tmp_acc[0-9]+", line)
  1351. def update_stores_with_parallel_reduction(self):
  1352. for i, line in enumerate(self.stores._lines):
  1353. if isinstance(line, str):
  1354. m = self.get_reduction_var_pattern(line)
  1355. if m:
  1356. var_name = m.group(0)
  1357. self.stores._lines[i] = line.replace(var_name, f"{var_name}_local")
  1358. @contextlib.contextmanager
  1359. def masked(self, mask):
  1360. """Context manager to add an additional mask to loads and stores."""
  1361. prior = self._load_mask
  1362. if prior:
  1363. mask = ops.and_(mask, prior)
  1364. if isinstance(mask, OpsValue):
  1365. mask = mask.value
  1366. assert isinstance(mask, CppCSEVariable)
  1367. # see NOTE [dtype of CppCSEVariable]
  1368. # mask's dtype should be bool
  1369. mask.dtype = torch.bool
  1370. self._load_mask = mask
  1371. try:
  1372. yield mask
  1373. finally:
  1374. self._load_mask = prior
  1375. def cache_fp32_cse_var_before_lowp_store(self, var_to_store):
  1376. """
  1377. https://github.com/pytorch/pytorch/issues/115260
  1378. For FusedSchedulerNode[node1, node2], the node2 loads what node1 stores and the buffer is
  1379. in low-precision floating point data type. When the output of node1 also serves as the output of the
  1380. kernel, the result of nodes would be different from the case when output of node1 is not the output
  1381. of the kernel (where we don't need to insert `to_dtype` for legalization). To address the problem, on
  1382. storing the lowp node1 output, we also add the inverse dtype conversion to high precision data type
  1383. to the cse cache.
  1384. Example (pseudo code):
  1385. node1_output = ...
  1386. node1_output_lowp = to_dtype(node1_output, dtype=torch.bfloat16)
  1387. store(buf, node1_output_lowp)
  1388. node2_input_lowp = load(buf)
  1389. node2_input = to_dtype(node2_input_lowp, dtype=torch.float)
  1390. Without cse cache trick:
  1391. node1_output = ...
  1392. node1_output_lowp = to_dtype(node1_output, dtype=torch.bfloat16)
  1393. store(buf, node1_output_lowp)
  1394. node2_input_lowp = node_output_lowp # hit store cache
  1395. node2_input = to_dtype(node2_input_lowp, dtype=torch.float)
  1396. With cse cache trick:
  1397. node1_output = ...
  1398. node1_output_lowp = to_dtype(node1_output, dtype=torch.bfloat16)
  1399. # also add `to_dtype(node1_input_lowp, dtype=torch.float)` -> `node1_output` to cse cache
  1400. store(buf, node1_output_lowp)
  1401. node2_input_lowp = node_output_lowp # hit store cache
  1402. node2_input = node1_output # hit cse cache
  1403. """
  1404. if var_to_store.dtype not in DTYPE_LOWP_FP:
# only need to cache the fp32 cse var when var_to_store is lowp data
  1406. return
  1407. def find_fp32_var(var, cache):
  1408. fp32_cse_var = None
  1409. fp32_cse_var_name = None
  1410. for expr, cse_var in cache.items():
  1411. if cse_var == var:
  1412. if is_to_lowp_dtype(expr):
  1413. m = re.search(r"tmp\d+", expr)
  1414. if m is not None:
  1415. fp32_cse_var_name = m.group()
  1416. if fp32_cse_var_name:
  1417. for cse_var in cache.values():
  1418. if cse_var.name == fp32_cse_var_name:
  1419. fp32_cse_var = cse_var
  1420. break
  1421. assert fp32_cse_var is not None
  1422. return fp32_cse_var
  1423. fp32_var = find_fp32_var(var_to_store, self.cse.cache)
  1424. if fp32_var:
  1425. self.cse.cache[get_lowp_to_fp32_expr(var_to_store, self)] = fp32_var
  1426. def scale_index_with_offset(
  1427. self, index: sympy.Expr, scale=1, itervar_idx=-1, offset=0
  1428. ):
  1429. var = self.itervars[itervar_idx]
  1430. replacement = {var: var * scale + offset}
  1431. new_index = sympy_subs(index, replacement)
  1432. return new_index
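# scale_index_with_offset rewrites a single itervar via sympy substitution.
# A minimal standalone sketch of the same rewrite (symbol names are illustrative):
def _example_scale_index_with_offset():
    import sympy
    x0, x0_inner = sympy.symbols("x0 x0_inner")
    index = 3 * x0 + 5
    # scale=16 with offset=x0_inner mimics tiling the x0 loop by 16
    new_index = index.subs({x0: x0 * 16 + x0_inner})
    assert sympy.expand(new_index) == 48 * x0 + 3 * x0_inner + 5
    return new_index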
  1433. def index_to_str(self, index: sympy.Expr) -> str:
  1434. """
  1435. Convert an index expr to a string that can be used in cpp code.
  1436. e.g. a sympy expression "s2" may actually appear as "ks1" in the cpp kernel.
  1437. """
  1438. return cexpr(self.rename_indexing(index))
  1439. def index_indirect_depends_on(self, index: sympy.Expr, itervar: sympy.Symbol):
  1440. """
Check if the index has a free symbol (a CppCSEVariable) that depends on `itervar`.
  1442. """
  1443. return any(
  1444. self.cse.varname_map[s.name].depends_on(itervar) # type: ignore[attr-defined]
  1445. for s in index.free_symbols
  1446. if s.name in self.cse.varname_map # type: ignore[attr-defined]
  1447. and isinstance(self.cse.varname_map[s.name], CppCSEVariable) # type: ignore[attr-defined]
  1448. )
  1449. def index_depends_on(self, index: sympy.Expr, itervar: sympy.Symbol):
  1450. return itervar in index.free_symbols or self.index_indirect_depends_on(
  1451. index, itervar
  1452. )
  1453. def var_ranges(self):
  1454. return dict(zip(self.itervars, self.ranges))
  1455. def check_bounds(
  1456. self,
  1457. expr: sympy.Expr,
  1458. size: sympy.Expr,
  1459. lower: bool,
  1460. upper: bool,
  1461. ):
  1462. if not (lower or upper):
  1463. return
  1464. indirect = free_symbol_is_type(expr, SymT.TMP)
  1465. if indirect:
  1466. # indexing in compute
  1467. csevar = ops.index_expr(expr, torch.int32).value
  1468. buffer = V.kernel.compute
  1469. else:
  1470. # indexing in loads
  1471. prior_compute = V.kernel.compute
  1472. try:
  1473. V.kernel.compute = self.loads
  1474. csevar = ops.index_expr(expr, torch.int32).value
  1475. finally:
  1476. V.kernel.compute = prior_compute
  1477. buffer = self.loads
  1478. size_str = V.kernel.sexpr(self.rename_indexing(size)) if upper else None
  1479. line = self.indirect_assert(csevar, "0" if lower else None, size_str)
  1480. self.cse.generate(buffer, line, assignment=False)
  1481. def load(self, name: str, index: sympy.Expr):
  1482. var = self.args.input(name)
  1483. index = self.rename_indexing(index)
  1484. line = f"{var}[{cexpr_index(index)}]"
  1485. if V.graph.get_dtype(name) in [torch.float16]:
  1486. line = f"static_cast<float>({line})"
  1487. csevar = self.cse.generate(self.loads, line)
  1488. csevar.update_on_args("load", (name, index), {})
  1489. return csevar
  1490. def store(self, name, index, value, mode=None):
  1491. assert "buf" in name
  1492. var = self.args.output(name)
  1493. self.cache_fp32_cse_var_before_lowp_store(value)
  1494. index = self.rename_indexing(index)
  1495. if mode is None:
  1496. line = f"{var}[{cexpr_index(index)}] = {value};"
  1497. elif mode == "atomic_add":
  1498. if not config.cpp.dynamic_threads and self.num_threads == 1:
  1499. line = f"{var}[{cexpr_index(index)}] += {value};"
  1500. else:
  1501. dtype = V.graph.get_dtype(name)
  1502. # mirroring static_cast<float>(...) in load:
  1503. value = f"static_cast<{DTYPE_TO_CPP[dtype]}>({value})"
  1504. line = f"atomic_add(&{var}[{cexpr_index(index)}], {value});"
  1505. else:
  1506. raise NotImplementedError(f"store mode={mode}")
  1507. self.stores.writeline(DeferredLine(name, line))
  1508. def reduction(self, dtype, src_dtype, reduction_type, value):
  1509. argmax_or_argmin = reduction_type in {"argmax", "argmin"}
  1510. reduction_key = src_dtype, reduction_type, value
  1511. if reduction_key in self.reduction_cse.reduction_cache:
  1512. return self.reduction_cse.reduction_cache[reduction_key]
  1513. acc = self.reduction_cse.generate(
  1514. self.loads, f"reduction {reduction_key}", write=False
  1515. )
  1516. self.is_reduction = True
  1517. if argmax_or_argmin:
  1518. prefix, parallel_prefix, local_init = argmax_argmin_prefix(
  1519. reduction_type, src_dtype, acc
  1520. )
  1521. self.local_reduction_init.writelines(local_init)
  1522. self.reduction_prefix.writelines(prefix)
  1523. self.parallel_reduction_prefix.writelines(parallel_prefix)
  1524. compare_op = (
  1525. "greater_or_nan" if reduction_type == "argmax" else "less_or_nan"
  1526. )
  1527. assert self.reduction_depth is not None
  1528. index = self.itervars[self.reduction_depth]
  1529. for i in range(self.reduction_depth + 1, len(self.itervars)):
  1530. index = index * self.ranges[i] + self.itervars[i]
  1531. self.stores.writelines(
  1532. [
  1533. f"if(!({compare_op}({acc}.value, {value}, {acc}.index, {cexpr_index(index)}))) {{",
  1534. f" {acc}.index = {cexpr_index(index)}; {acc}.value = {value};",
  1535. "}",
  1536. ]
  1537. )
  1538. acc_local = f"{acc}_local"
  1539. num_threads = parallel_num_threads()
  1540. acc_per_thread = f"{acc}_arr[{num_threads}]"
  1541. acc_local_in_array = acc_per_thread.replace(f"[{num_threads}]", "[tid]")
  1542. self.parallel_reduction_suffix.writelines(
  1543. [
  1544. f"for (int tid = 0; tid < {num_threads}; tid++)",
  1545. "{",
  1546. f" if(!({compare_op}({acc}.value, {acc_local_in_array}.value, {acc}.index, {acc_local_in_array}.index))) {{",
  1547. f" {acc}.index = {acc_local_in_array}.index; {acc}.value = {acc_local_in_array}.value;",
  1548. " }",
  1549. "}",
  1550. ],
  1551. )
  1552. self.local_reduction_stores.writelines(
  1553. [
  1554. f"{acc_local_in_array} = {acc_local};",
  1555. ]
  1556. )
  1557. else:
  1558. acc_type = reduction_acc_type(reduction_type, dtype)
  1559. self.reduction_prefix.writeline(
  1560. f"{acc_type} {acc} = {reduction_init(reduction_type, dtype)};"
  1561. )
  1562. self.stores.writeline(
  1563. f"{acc} = {reduction_combine(reduction_type, acc, value)};"
  1564. )
  1565. self._gen_parallel_reduction_buffers(acc, acc_type, reduction_type, dtype)
  1566. result = reduction_project(reduction_type, acc)
  1567. self.reduction_cse.reduction_cache[reduction_key] = result
  1568. return result
  1569. def store_reduction(self, name, index, value):
  1570. index = self.rename_indexing(index)
  1571. var = self.args.output(name)
  1572. self.reduction_suffix.writeline(
  1573. DeferredLine(name, f"{var}[{cexpr_index(index)}] = {value};")
  1574. )
  1575. def set_ranges(self, lengths, reduction_lengths):
  1576. if self.call_ranges:
  1577. assert self.call_ranges == tuple(lengths) + tuple(
  1578. reduction_lengths
  1579. ), f"{self.call_ranges} == {tuple(lengths)} + {tuple(reduction_lengths)}"
  1580. assert self.reduction_depth == len(lengths)
  1581. else:
  1582. self.call_ranges = tuple(lengths) + tuple(reduction_lengths)
  1583. self.ranges = [self.rename_indexing(x) for x in self.call_ranges]
  1584. self.itervars = [
  1585. sympy_index_symbol_with_prefix(SymT.XBLOCK, n)
  1586. for n in range(len(self.ranges))
  1587. ]
  1588. self.reduction_depth = len(lengths)
  1589. return (
  1590. self.itervars[: self.reduction_depth],
  1591. self.itervars[self.reduction_depth :],
  1592. )
  1593. def size_hint(self):
  1594. return V.graph.sizevars.size_hint(
  1595. sympy_product(self.call_ranges), fallback=8192
  1596. )
  1597. def codegen_loops_impl(self, loop_nest, code, worksharing):
  1598. threads = parallel_num_threads()
  1599. assert self.call_ranges is not None
  1600. kernels = loop_nest.get_kernels()
  1601. if any(isinstance(kernel, OuterLoopFusedKernel) for kernel in kernels):
  1602. assert len(kernels) == 1
  1603. assert isinstance(kernels[0], OuterLoopFusedKernel)
  1604. par_depth = kernels[0].decide_parallel_depth(
  1605. loop_nest.max_parallel_depth(), threads
  1606. )
  1607. else:
  1608. par_depth = self.decide_parallel_depth(
  1609. loop_nest.max_parallel_depth(), threads
  1610. )
  1611. with contextlib.ExitStack() as stack:
  1612. if par_depth:
  1613. if loop_nest.is_reduction_only():
  1614. # need to close the worksharing scope to define reduction vars outside it
  1615. worksharing.close()
  1616. else:
  1617. worksharing.parallel(threads)
  1618. loop_nest.mark_parallel(par_depth)
  1619. elif threads > 1:
  1620. if worksharing.single():
  1621. stack.enter_context(code.indent())
  1622. def gen_loop_kernel(loop: LoopLevel):
  1623. def is_parallel_reduction(loop):
  1624. root = loop.get_root()
  1625. return root.is_reduction and root.parallel
  1626. kernels = loop.get_kernels()
  1627. assert len(kernels) == 1
  1628. if not isinstance(
  1629. kernels[0], OuterLoopFusedKernel
  1630. ) and is_parallel_reduction(loop):
  1631. kernels[0].update_stores_with_parallel_reduction()
  1632. gen_kernel(kernels[0])
  1633. def gen_kernel(kernel):
  1634. if isinstance(kernel, OuterLoopFusedKernel):
  1635. for loop in kernel.inner:
  1636. if loop.inner:
  1637. gen_loops(loop.inner, loop.is_reduction)
  1638. else:
  1639. with contextlib.ExitStack() as stack:
# If any kernel exists at the final outer loop fusion level,
# its code should be placed within its own indentation scope to
# prevent duplicate variable definitions.
  1643. stack.enter_context(code.indent())
  1644. gen_loop_kernel(loop)
  1645. else:
  1646. with contextlib.ExitStack() as stack:
  1647. assert kernel
  1648. if hasattr(kernel, "codegen_inner_loops"):
  1649. code.splice(kernel.preloads)
  1650. kernel.codegen_inner_loops(code)
  1651. stack.enter_context(code.indent())
  1652. code.splice(kernel.loads)
  1653. code.splice(kernel.compute)
  1654. code.splice(kernel.stores)
  1655. if hasattr(kernel, "codegen_inner_loops"):
  1656. code.splice(kernel.poststores)
  1657. def get_reduction_code_buffer(loops, buffer="prefix"):
  1658. assert buffer in ("prefix", "suffix", "local")
  1659. for loop in loops:
  1660. for kernel in loop.get_kernels():
  1661. if buffer == "local":
  1662. return (
  1663. kernel.local_reduction_init,
  1664. kernel.local_reduction_stores,
  1665. )
  1666. elif buffer == "suffix":
  1667. suffix = kernel.reduction_suffix
  1668. if loop.parallel:
  1669. suffix = kernel.parallel_reduction_suffix + suffix
  1670. return suffix
  1671. else:
  1672. prefix = kernel.reduction_prefix
  1673. if loop.parallel:
  1674. prefix = prefix + kernel.parallel_reduction_prefix
  1675. else:
  1676. prefix = prefix + kernel.non_parallel_reduction_prefix
  1677. return prefix
  1678. def gen_loops(loops: List[LoopLevel], in_reduction=False):
  1679. with contextlib.ExitStack() as stack_outer:
  1680. local_reduction_init = local_reduction_stores = None
  1681. if loops:
  1682. loop = loops[0]
  1683. if loop.is_reduction and not in_reduction:
  1684. reduction_prefix = get_reduction_code_buffer(loops)
  1685. if reduction_prefix:
  1686. stack_outer.enter_context(code.indent())
  1687. code.splice(reduction_prefix)
  1688. if loop_nest.is_reduction_only() and loop.parallel:
  1689. (
  1690. local_reduction_init,
  1691. local_reduction_stores,
  1692. ) = get_reduction_code_buffer(loops, "local")
  1693. worksharing.parallel(threads)
  1694. if local_reduction_init:
  1695. assert local_reduction_stores
  1696. code.splice(local_reduction_init)
  1697. for loop in loops:
  1698. gen_loop(loop)
  1699. if loops:
  1700. loop = loops[0]
  1701. if loop_nest.is_reduction_only() and loop.parallel:
  1702. if local_reduction_stores:
  1703. code.splice(local_reduction_stores)
  1704. worksharing.close()
  1705. if loop.is_reduction and not in_reduction:
  1706. code.splice(get_reduction_code_buffer(loops, "suffix"))
  1707. def gen_loop(loop: LoopLevel):
  1708. with contextlib.ExitStack() as stack:
  1709. loop_lines = loop.lines()
  1710. if loop_lines is None:
  1711. return
  1712. code.writelines(loop_lines)
  1713. stack.enter_context(code.indent())
  1714. # generate inner loops or loop body
  1715. if loop.inner:
  1716. gen_loops(loop.inner, loop.is_reduction)
  1717. else:
  1718. gen_loop_kernel(loop)
  1719. stack.enter_context(code.indent())
  1720. if loop_nest.root:
  1721. gen_loops(loop_nest.root)
  1722. else:
  1723. gen_kernel(loop_nest.kernel)
  1724. def codegen_loops(self, code, worksharing):
  1725. loop_nest = LoopNestWithSplit.build(self)
  1726. self.codegen_loops_impl(loop_nest, code, worksharing)
  1727. @property
  1728. def assert_function(self) -> str:
  1729. if V.graph.aot_mode:
  1730. # TODO: Using AOTI_TORCH_CHECK is causing performance drop for some models
  1731. # compared with JIT Inductor which uses TORCH_CHECK
  1732. return "AOTI_TORCH_CHECK"
  1733. else:
  1734. return "TORCH_CHECK"
  1735. def decide_parallel_depth(self, max_parallel_depth, threads):
  1736. assert self.call_ranges is not None
  1737. ranges = self.call_ranges[:max_parallel_depth]
  1738. seq = self.size_hint()
  1739. par = 1
  1740. depth = 0
  1741. for expr in ranges:
  1742. hint = V.graph.sizevars.size_hint(expr, fallback=8192)
  1743. if par >= 2 * threads or par == threads:
  1744. break
  1745. if seq // threads < config.cpp.min_chunk_size:
  1746. # not enough work
  1747. break
  1748. depth += 1
  1749. par *= hint
  1750. seq /= hint
# if we assume the thread number is dynamic, make sure we
# have at least one parallel scope and let the OMP runtime
# manage the serial vs. parallel choice.
  1754. if config.cpp.dynamic_threads and depth == 0 and len(ranges) > 0:
  1755. depth = 1
  1756. return depth
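# A standalone sketch of the heuristic above: accumulate parallelism over the
# outer ranges until it covers the thread count, and stop once the remaining
# sequential work per thread would drop below the minimum chunk size
# (the dynamic_threads special case is omitted; the numbers are made up):
def _example_decide_parallel_depth(range_hints, threads, min_chunk_size=4096):
    seq = 1
    for hint in range_hints:
        seq *= hint  # total iteration count, analogous to size_hint()
    par, depth = 1, 0
    for hint in range_hints:
        if par >= 2 * threads or par == threads:
            break
        if seq // threads < min_chunk_size:
            break  # not enough work to be worth splitting further
        depth += 1
        par *= hint
        seq /= hint
    return depth
# e.g. _example_decide_parallel_depth([64, 128, 32], threads=16) == 1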
  1757. @contextlib.contextmanager
  1758. def write_to_suffix(self):
  1759. prior = (self.loads, self.compute, self.stores, self.cse)
  1760. self.loads = IndentedBuffer()
  1761. self.compute = IndentedBuffer()
  1762. self.stores = IndentedBuffer()
  1763. self.cse = self.cse.clone()
  1764. yield
  1765. self.reduction_suffix.splice(self.loads)
  1766. self.reduction_suffix.splice(self.compute)
  1767. self.reduction_suffix.splice(self.stores)
  1768. (self.loads, self.compute, self.stores, self.cse) = prior
  1769. def create_cse_var(self, *args, **kwargs):
  1770. return CppCSEVariable(*args, **kwargs)
  1771. class CppVecKernel(CppKernel):
  1772. overrides = CppVecOverrides # type: ignore[assignment]
  1773. def __init__(
  1774. self,
  1775. args,
  1776. num_threads,
  1777. tiling_factor=0,
  1778. tiling_idx=-1,
  1779. tiling_dtype=torch.float,
  1780. ):
  1781. super().__init__(args, num_threads)
  1782. self.vec_isa = codecache.pick_vec_isa()
  1783. assert self.vec_isa
  1784. if tiling_factor == 0:
  1785. tiling_factor = self.vec_isa.nelements(dtype=tiling_dtype)
  1786. self.tiling_factor = tiling_factor
  1787. self.tiling_idx = tiling_idx
  1788. def _try_get_const_stride(self, index: sympy.Expr, itervar: sympy.Symbol):
  1789. if self.index_indirect_depends_on(index, itervar):
  1790. return None
  1791. for indirect_var in (
  1792. self.cse.varname_map[s.name] # type: ignore[attr-defined]
  1793. for s in index.free_symbols
  1794. if symbol_is_type(s, SymT.TMP)
  1795. ):
  1796. assert isinstance(indirect_var, CppCSEVariable)
  1797. if indirect_var.is_vec:
  1798. return None
  1799. stride = stride_at_vec_range(index, itervar, self.tiling_factor)
  1800. return stride if stride.is_number else None
  1801. def _get_num_vectors(self, dtype: torch.dtype) -> int:
  1802. num_vectors = math.ceil(
  1803. self.tiling_factor * dtype.itemsize * 8 / self.vec_isa.bit_width()
  1804. )
  1805. assert num_vectors >= 1
  1806. return num_vectors
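# Worked example of the vector-count math above, assuming a 256-bit ISA (AVX2)
# and a tiling factor of 8 picked for torch.float:
#   float32:  ceil(8 * 4 * 8 / 256) = 1  -> a single at::vec::Vectorized<float>
#   int64:    ceil(8 * 8 * 8 / 256) = 2  -> at::vec::VectorizedN<int64_t, 2>
#   bfloat16: ceil(8 * 2 * 8 / 256) = 1  -> only half the lanes of one vector are used
def _example_num_vectors(tiling_factor=8, itemsize=8, isa_bit_width=256):
    import math
    return math.ceil(tiling_factor * itemsize * 8 / isa_bit_width)  # 2 for int64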
  1807. def _get_vec_type(self, dtype: torch.dtype) -> str:
  1808. num_vectors = self._get_num_vectors(dtype)
  1809. if num_vectors == 1:
  1810. return f"at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>"
  1811. else:
  1812. return f"at::vec::VectorizedN<{DTYPE_TO_CPP[dtype]},{num_vectors}>"
  1813. def _get_mask_type(self, dtype: torch.dtype = torch.float) -> str:
  1814. if dtype == torch.bool:
  1815. return ""
  1816. num_vectors = self._get_num_vectors(dtype)
  1817. return f"at::vec::VecMask<{DTYPE_TO_CPP[dtype]},{num_vectors}>"
  1818. def _get_mask_cast(self, mask: CppCSEVariable, dtype: torch.dtype) -> str:
  1819. assert mask.dtype == torch.bool, repr(mask)
  1820. num_vectors = self._get_num_vectors(dtype)
  1821. return f"{mask}.template cast<{DTYPE_TO_CPP[dtype]},{num_vectors}>()"
  1822. def get_reduction_var_pattern(self, line: str):
  1823. return re.search("tmp_acc[0-9]+_vec", line)
  1824. def _get_vec_load_line(
  1825. self,
  1826. var: str,
  1827. index: sympy.Expr,
  1828. dtype: torch.dtype,
  1829. load_mask: Optional[CppCSEVariable] = None,
  1830. ):
  1831. """
  1832. Get a load line str that loads a vector from `var` at `index` of type `dtype`.
  1833. If `load_mask` is not None, we do a masked load accordingly.
  1834. Notes on the `dtype`:
1. We always load `self.tiling_factor` elements regardless of the `dtype`, which means
   we use half of the vector lanes for 16-bit data types and a quarter of the vector
   lanes for 8-bit data types.
2. `torch.bool` and `torch.uint8` may represent masks, and we load them as float mask vectors.
  1839. """
  1840. opt_ctx: OptimizationContext = get_current_node_opt_ctx()
  1841. assert opt_ctx is not None
  1842. cpp_type = DTYPE_TO_CPP[dtype]
  1843. num_vectors = self._get_num_vectors(dtype)
  1844. load_mask_str = None
  1845. if load_mask:
  1846. if not load_mask.is_vec:
# TODO: avoid hard-coding torch.float
  1848. load_mask_str = f"{self._get_mask_type(torch.float)}::from({load_mask})"
  1849. else:
  1850. load_mask_str = f"{self._get_mask_cast(load_mask, torch.float)}"
  1851. loadbuf = f"{var} + {cexpr_index(index)}" if index != 0 else var
  1852. if dtype == torch.bool:
  1853. # TODO: should we consider load mask here?
  1854. line = f"{self._get_mask_type()}::from({loadbuf})"
  1855. else:
  1856. line = (
  1857. f"{load_mask_str}.template loadu<{cpp_type},{num_vectors}>({loadbuf})"
  1858. if load_mask_str
  1859. else f"{self._get_vec_type(dtype)}::loadu({loadbuf}, {self.tiling_factor})"
  1860. )
  1861. return line
  1862. def _load_or_store_non_contiguous(
  1863. self,
  1864. var: Optional[str],
  1865. index: sympy.Expr,
  1866. dtype: torch.dtype,
  1867. buffer: Optional[IndentedBuffer] = None,
  1868. store_value: Optional[Union[str, CppCSEVariable]] = None,
  1869. ) -> Optional[CppCSEVariable]:
  1870. """
  1871. Load or store a vector in a non-contiguous way. The vector is initialized from an array that is
  1872. filled in an inner loop over the tiling factor.
  1873. :param var: buffer to load from or store to, i.e. `var[transformed(index)]`. If None, we load the index
  1874. as index expression, i.e. `transformed(index)`.
:param index: index into the `var`, or the index expression on its own if `var` is None.
  1876. The `index` could contain indirect indexing or the tiling itervar. When used in
  1877. the inner loop, the index is transformed as follows:
  1878. 1. the index is linearized along the tiling dim.
  1879. 2. the indirect indexing vector variables are transformed into arrays over the tiling dim.
  1880. :param dtype: data type of `var` or `index` if `var` is None.
  1881. :param buffer: the code buffer to write the generated code to. If None, we write to `self.loads`.
  1882. :param store_value: the value to store. If None, we load the vector.
  1883. :return: a CppCSEVariable that represents the loaded vector or None if it is a store.
  1884. """
  1885. assert not store_value or var is not None, "store var must be provided"
  1886. if buffer is None:
  1887. buffer = self.loads
  1888. def get_result_size(dtype: torch.dtype) -> int:
  1889. if dtype.itemsize < 4:
  1890. return self.tiling_factor * (4 // dtype.itemsize)
  1891. else:
  1892. return self.tiling_factor
  1893. def vec_to_array(vec_var: CppCSEVariable) -> CppCSEVariable:
  1894. assert vec_var.is_vec
  1895. code = BracesBuffer()
  1896. code.writeline("[&]")
  1897. with code.indent():
  1898. vec_dtype = vec_var.dtype
  1899. assert vec_dtype is not None
  1900. if vec_dtype == torch.bool:
  1901. vec_dtype = torch.float
  1902. result_size = get_result_size(vec_dtype)
  1903. code.writeline(
  1904. f"__at_align__ std::array<{DTYPE_TO_CPP[vec_dtype]}, {result_size}> tmpbuf;"
  1905. )
  1906. line = f"{vec_var}.store(tmpbuf.data());"
  1907. code.writeline(line)
  1908. code.writeline("return tmpbuf;")
  1909. code.writeline("()")
  1910. csevar = self.cse.generate(buffer, code)
  1911. assert isinstance(csevar, CppCSEVariable)
  1912. return csevar
  1913. opt_ctx: OptimizationContext = get_current_node_opt_ctx()
  1914. assert opt_ctx is not None
  1915. code = BracesBuffer()
  1916. code.writeline("[&]")
  1917. with code.indent():
  1918. result_size = get_result_size(dtype)
  1919. result_declare = (
  1920. f"__at_align__ std::array<{DTYPE_TO_CPP[dtype]}, {result_size}> tmpbuf;"
  1921. )
  1922. code.writeline(result_declare)
  1923. if store_value:
  1924. code.writeline(f"{store_value}.store(tmpbuf.data());")
  1925. itervar_inner = sympy_index_symbol(
  1926. f"{self.itervars[self.tiling_idx]}_inner"
  1927. )
  1928. replacements = {}
  1929. for indirect_var in (
  1930. self.cse.varname_map[s.name] # type: ignore[attr-defined]
  1931. for s in index.free_symbols
  1932. if symbol_is_type(s, SymT.TMP)
  1933. ):
  1934. assert isinstance(indirect_var, CppCSEVariable)
  1935. if indirect_var.is_vec:
  1936. array_var = vec_to_array(indirect_var)
  1937. replacements[indirect_var] = f"{array_var}[{itervar_inner}]"
  1938. index = self.scale_index_with_offset(
  1939. index, itervar_idx=self.tiling_idx, offset=itervar_inner
  1940. )
  1941. load_mask = None
  1942. if self._load_mask is not None:
  1943. assert not store_value, "unexpected store with load mask"
  1944. assert isinstance(self._load_mask, CppCSEVariable), self._load_mask
  1945. if self._load_mask.is_vec:
  1946. load_mask = f"{self._load_mask}.is_masked({itervar_inner})"
  1947. else:
  1948. load_mask = f"{self._load_mask} != 0"
  1949. if codecache.is_gcc():
  1950. code.writeline(f"#pragma GCC unroll {self.tiling_factor}")
  1951. else:
  1952. code.writeline(f"#pragma unroll {self.tiling_factor}")
  1953. code.writeline(
  1954. f"for (long {itervar_inner} = 0; {itervar_inner} < {self.tiling_factor}; {itervar_inner}++)"
  1955. )
  1956. with code.indent(), contextlib.ExitStack() as stack:
  1957. index_c = cexpr_index(index)
  1958. for indirect_var in replacements:
  1959. index_c = re.sub(
  1960. r"\b" + f"{indirect_var}" + r"\b",
  1961. replacements[indirect_var],
  1962. index_c,
  1963. )
  1964. rhs = f"{var}[{index_c}]" if var is not None else f"{index_c}"
  1965. if load_mask:
  1966. code.writeline(f"if ({load_mask})")
  1967. stack.enter_context(code.indent())
  1968. if store_value:
  1969. code.writeline(f"{rhs} = tmpbuf[{itervar_inner}];")
  1970. else:
  1971. code.writeline(f"tmpbuf[{itervar_inner}] = {rhs};")
  1972. if not store_value:
  1973. load_line = self._get_vec_load_line("tmpbuf.data()", 0, dtype) # type: ignore[arg-type]
  1974. code.writeline(f"return {load_line};")
  1975. code.writeline("()")
  1976. if store_value:
  1977. code.writeline(";")
  1978. buffer.splice(code)
  1979. return None
  1980. else:
  1981. csevar = self.cse.generate(buffer, code)
  1982. assert isinstance(csevar, CppCSEVariable)
  1983. csevar.is_vec = True
  1984. return csevar
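# Rough shape of the code generated above for a non-contiguous float load with
# tiling factor 8 under GCC; identifiers such as in_ptr0/x0/x0_inner and the
# stride 42 are illustrative only:
_EXAMPLE_NON_CONTIGUOUS_LOAD_SHAPE = r"""
[&]
{
    __at_align__ std::array<float, 8> tmpbuf;
    #pragma GCC unroll 8
    for (long x0_inner = 0; x0_inner < 8; x0_inner++)
    {
        tmpbuf[x0_inner] = in_ptr0[42L*(x0 + x0_inner)];
    }
    return at::vec::Vectorized<float>::loadu(tmpbuf.data(), 8);
}
()
"""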
  1985. def load(self, name: str, index: sympy.Expr):
  1986. opt_ctx: OptimizationContext = get_current_node_opt_ctx()
  1987. var = self.args.input(name)
  1988. index = self.rename_indexing(index)
  1989. dtype = V.graph.get_dtype(name)
  1990. tiling_var = self.itervars[self.tiling_idx]
  1991. stride = self._try_get_const_stride(index, tiling_var)
  1992. if stride == 0:
  1993. # load scalar and lazily broadcast it on demand
  1994. return super().load(name, index)
  1995. elif stride == 1:
  1996. # load contiguously
  1997. line = self._get_vec_load_line(var, index, dtype, self._load_mask)
  1998. csevar = self.cse.generate(self.loads, line) # type: ignore[assignment]
  1999. else:
  2000. csevar = self._load_or_store_non_contiguous(var, index, dtype) # type: ignore[assignment]
  2001. assert isinstance(csevar, CppCSEVariable)
  2002. csevar.update_on_args("load", (name, index), {})
  2003. csevar.is_vec = True
  2004. return csevar
  2005. def _get_store_line(
  2006. self,
  2007. value: Union[str, CppCSEVariable],
  2008. var: str,
  2009. index: sympy.Expr,
  2010. dtype: torch.dtype,
  2011. ):
  2012. """
  2013. Get a store line buffer that stores `value` into `var` at `index` of `dtype`. It handles
  2014. both contiguous and non-contiguous store cases.
  2015. :param value: Vectorized type templaterized on `dtype`.
  2016. :param var: buffer to store into.
  2017. :index: index into the `var`.
  2018. """
  2019. # when value's type is str (e.g., welford reduction), caller should make sure
  2020. # it is a vector
  2021. assert isinstance(value, str) or (
  2022. isinstance(value, CppCSEVariable) and value.is_vec
  2023. ), value
  2024. tiling_var = self.itervars[self.tiling_idx]
  2025. var_expr = f"{var} + {cexpr_index(index)}"
  2026. stride = self._try_get_const_stride(index, tiling_var)
  2027. code = IndentedBuffer()
  2028. if stride == 1:
  2029. if dtype == torch.float:
  2030. code.writeline(f"{value}.store({var_expr});")
  2031. else:
  2032. code.writeline(f"{value}.store({var_expr}, {self.tiling_factor});")
  2033. else:
  2034. self._load_or_store_non_contiguous(
  2035. var, index, dtype, buffer=code, store_value=value
  2036. )
  2037. return code
  2038. def store(self, name, index, value, mode=None):
  2039. assert "buf" in name
  2040. assert mode is None
  2041. assert isinstance(value, CppCSEVariable), value
  2042. if not value.is_vec:
  2043. # this happens when we store a scalar into a vectorized buffer like "fill"
  2044. value = self.broadcast(value)
  2045. opt_ctx: OptimizationContext = get_current_node_opt_ctx()
  2046. var = self.args.output(name)
  2047. self.cache_fp32_cse_var_before_lowp_store(value)
  2048. index = self.rename_indexing(index)
  2049. code = self._get_store_line(value, var, index, V.graph.get_dtype(name))
  2050. self.stores.splice(code.map(lambda x: DeferredLine(name, x)))
  2051. def reduction(self, dtype, src_dtype, reduction_type, value):
  2052. assert reduction_type in {
  2053. "max",
  2054. "min",
  2055. "sum",
  2056. "prod",
  2057. "xor_sum",
  2058. "welford_reduce",
  2059. "welford_combine",
  2060. }
  2061. assert dtype == src_dtype
  2062. assert dtype in [torch.float, torch.int64]
  2063. assert isinstance(value, CppCSEVariable), value
  2064. if not value.is_vec:
  2065. value = self.broadcast(value)
  2066. reduction_key = src_dtype, reduction_type, value
  2067. if reduction_key in self.reduction_cse.reduction_cache:
  2068. return self.reduction_cse.reduction_cache[reduction_key]
  2069. vec_ns = "at::vec"
  2070. vec = f"{vec_ns}::Vectorized<{DTYPE_TO_CPP[dtype]}>"
  2071. acc_type = reduction_acc_type(reduction_type, dtype)
  2072. acc_type_vec = self.reduction_acc_type_vec(reduction_type, dtype)
  2073. acc = self.reduction_cse.generate(
  2074. self.loads, f"reduction {reduction_key}", write=False
  2075. )
  2076. acc_vec = f"{acc}_vec"
  2077. self.is_reduction = True
  2078. self.reduction_prefix.writeline(
  2079. f"{acc_type} {acc} = {reduction_init(reduction_type, dtype)};"
  2080. )
  2081. self.reduction_prefix.writeline(
  2082. f"{acc_type_vec} {acc_vec} = {self.reduction_init_vec(reduction_type, dtype)};"
  2083. )
  2084. # save the reciprocal of weights for welford reduce if using static shape
  2085. reduction_size = functools.reduce(
  2086. lambda x, y: x * y, self.ranges[self.reduction_depth :]
  2087. )
  2088. if reduction_type == "welford_reduce":
  2089. reduction_factor = (
  2090. self.tiling_factor if self.tiling_idx >= self.reduction_depth else 1
  2091. )
  2092. self.weight_recp_vec_range = FloorDiv(reduction_size, reduction_factor)
  2093. self.non_parallel_reduction_prefix.writeline(
  2094. self.welford_weight_reciprocal_vec(dtype, None)
  2095. )
  2096. self.stores.writeline(
  2097. f"{acc_vec} = {self.reduction_combine_vec(reduction_type, acc_vec, value, True)};"
  2098. )
  2099. else:
  2100. self.stores.writeline(
  2101. f"{acc_vec} = {self.reduction_combine_vec(reduction_type, acc_vec, value)};"
  2102. )
  2103. self._gen_parallel_reduction_buffers(
  2104. acc,
  2105. acc_type,
  2106. reduction_type,
  2107. dtype,
  2108. )
  2109. self._gen_parallel_reduction_buffers(
  2110. acc_vec,
  2111. acc_type_vec,
  2112. reduction_type,
  2113. dtype,
  2114. reduction_combine_fn=self.reduction_combine_vec,
  2115. reduction_init_fn=self.reduction_init_vec,
  2116. welford_weight_reciprocal_vec_fn=self.welford_weight_reciprocal_vec,
  2117. )
  2118. tmpvar: Union[str, CSEVariable]
  2119. if self.tiling_idx >= self.reduction_depth:
  2120. # Horizontal reduction
  2121. if is_welford_reduction(reduction_type):
  2122. assert (
  2123. self._get_num_vectors(dtype) == 1
  2124. ), "Welford reduction does not support VectorizedN (N>1)"
  2125. next_value = f"welford_vec_reduce_all({acc_vec})"
  2126. else:
  2127. reduce_all_body = (
  2128. "{ return "
  2129. + self.reduction_combine_vec(reduction_type, "x", "y")
  2130. + "; }"
  2131. )
  2132. vec = f"at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>"
  2133. vec_reduce_all_func = f"at::vec::vec_reduce_all<{DTYPE_TO_CPP[dtype]}>"
  2134. next_value = f"{vec_reduce_all_func}([]({vec}& x, {vec}& y) {reduce_all_body}, {acc_vec})"
  2135. self.reduction_suffix.writeline(
  2136. f"{acc} = {reduction_combine(reduction_type, acc, next_value)};"
  2137. )
  2138. tmpvar = acc
  2139. else:
  2140. tmpvar = acc_vec
  2141. result = reduction_project(reduction_type, tmpvar)
  2142. self.reduction_cse.reduction_cache[reduction_key] = result
  2143. return result
  2144. def store_reduction(self, name, index, value):
  2145. index = self.rename_indexing(index)
  2146. var = self.args.output(name)
  2147. out_dtype = V.graph.get_dtype(name)
  2148. dtype = torch.float if out_dtype.is_floating_point else torch.int64
  2149. code = IndentedBuffer()
  2150. if self.tiling_idx >= self.reduction_depth:
  2151. # Horizontal reduction
  2152. code.writeline(
  2153. f"{var}[{cexpr_index(index)}] = static_cast<{DTYPE_TO_CPP[out_dtype]}>({value});"
  2154. )
  2155. else:
  2156. # Vertical reduction
  2157. if out_dtype != dtype:
  2158. converted_value = f"{DTYPE_TO_CPP[out_dtype]}_{value}"
  2159. code.writeline(
  2160. f"auto {converted_value} = at::vec::convert<{DTYPE_TO_CPP[out_dtype]}>({value});"
  2161. )
  2162. value = converted_value
  2163. code.splice(self._get_store_line(value, var, index, out_dtype))
  2164. self.reduction_suffix.splice(code.map(lambda x: DeferredLine(name, x)))
  2165. def broadcast(self, scalar_var: CppCSEVariable) -> CppCSEVariable:
  2166. assert not scalar_var.is_vec
  2167. if scalar_var.dtype == torch.bool:
  2168. vec_var = self.cse.generate(
  2169. self.compute, f"{self._get_mask_type()}::from({scalar_var.name})"
  2170. )
  2171. else:
  2172. assert scalar_var.dtype is not None
  2173. vec_var = self.cse.generate(
  2174. self.compute,
  2175. f"{self._get_vec_type(scalar_var.dtype)}({scalar_var.name})",
  2176. )
  2177. assert isinstance(vec_var, CppCSEVariable)
  2178. vec_var.dtype = scalar_var.dtype
  2179. vec_var.dependent_itervars = scalar_var.dependent_itervars
  2180. vec_var.is_vec = True
  2181. return vec_var
  2182. def arange(self, index: CppCSEVariable, stride: sympy.Symbol) -> CppCSEVariable:
  2183. assert not index.is_vec
  2184. assert index.dtype is not None
  2185. csevar = self.cse.generate(
  2186. self.compute,
  2187. f"{self._get_vec_type(index.dtype)}::arange({index}, {stride})",
  2188. )
  2189. assert isinstance(csevar, CppCSEVariable)
  2190. csevar.dtype = index.dtype
  2191. csevar.is_vec = True
  2192. return csevar
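# Examples of the expressions broadcast()/arange() hand to the CSE, assuming
# AVX2 with a tiling factor of 8 for float (so int64 needs two vectors);
# the variable names tmp1/tmp2 are illustrative:
#   broadcast(float tmp1)  -> "at::vec::Vectorized<float>(tmp1)"
#   broadcast(bool tmp1)   -> "at::vec::VecMask<float,1>::from(tmp1)"
#   arange(int64 tmp2, 1)  -> "at::vec::VectorizedN<int64_t,2>::arange(tmp2, 1)"
def _example_arange_expr(var="tmp2", stride=1, vec_type="at::vec::VectorizedN<int64_t,2>"):
    return f"{vec_type}::arange({var}, {stride})"  # mirrors the f-string in arange()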
  2193. def reduction_init_vec(self, reduction_type, dtype):
  2194. scalar_type = DTYPE_TO_COMPUTATION_DTYPE[dtype]
  2195. vec_type = self._get_vec_type(scalar_type)
  2196. if is_welford_reduction(reduction_type):
  2197. return f"Welford<{vec_type}>()"
  2198. scalar_init = reduction_init(reduction_type, dtype)
  2199. return f"{vec_type}({scalar_init})"
  2200. def reduction_acc_type_vec(self, reduction_type, dtype):
  2201. assert reduction_type not in {"argmin", "argmax"}
  2202. scalar_type = DTYPE_TO_COMPUTATION_DTYPE[dtype]
  2203. vec_type = self._get_vec_type(scalar_type)
  2204. if is_welford_reduction(reduction_type):
  2205. return f"Welford<{vec_type}>"
  2206. return vec_type
  2207. def welford_weight_reciprocal_vec(self, dtype, num_threads=None):
  2208. vec_num_range_thread = (
  2209. CeilDiv(self.weight_recp_vec_range, num_threads)
  2210. if num_threads
  2211. else self.weight_recp_vec_range
  2212. )
  2213. vec_num_range_thread_expr = cexpr_index(vec_num_range_thread)
  2214. return f"static WeightRecp<{self._get_vec_type(dtype)}> weight_recps({vec_num_range_thread_expr});"
  2215. def reduction_combine_vec(
  2216. self, reduction_type, var, next_value, use_weight_recps=False
  2217. ):
  2218. if reduction_type == "max":
  2219. return f"at::vec::maximum({var}, {next_value})"
  2220. elif reduction_type == "min":
  2221. return f"at::vec::minimum({var}, {next_value})"
  2222. elif reduction_type == "sum":
  2223. return f"{var} + {next_value}"
  2224. elif reduction_type == "prod":
  2225. return f"{var} * {next_value}"
  2226. elif reduction_type == "xor_sum":
  2227. return f"{var} ^ {next_value}"
  2228. elif reduction_type == "welford_reduce":
  2229. if use_weight_recps:
  2230. return f"welford_combine({var}, {next_value}, &weight_recps)"
  2231. else:
  2232. return f"welford_combine({var}, {next_value})"
  2233. elif reduction_type == "welford_combine":
  2234. if isinstance(next_value, tuple):
  2235. # When reading a value from Inductor IR we have a tuple of variable names
  2236. mean, m2, weight = next_value
  2237. else:
  2238. # When combining intermediate accumulators we have a Welford<T> struct
  2239. mean, m2, weight = reduction_project(reduction_type, next_value)
  2240. return f"welford_combine({var}, {{{mean}, {m2}, {weight}}})"
  2241. else:
  2242. raise NotImplementedError
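# reduction_combine_vec only formats the combining expression; for instance
# (accumulator/value names are illustrative):
#   ("sum", "tmp_acc0_vec", "tmp3") -> "tmp_acc0_vec + tmp3"
#   ("max", "tmp_acc0_vec", "tmp3") -> "at::vec::maximum(tmp_acc0_vec, tmp3)"
#   ("welford_combine", "tmp_acc0_vec", ("mean", "m2", "w"))
#       -> "welford_combine(tmp_acc0_vec, {mean, m2, w})"
def _example_combine_sum(var="tmp_acc0_vec", next_value="tmp3"):
    return f"{var} + {next_value}"  # the same string the "sum" branch produces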
  2243. def indirect_assert(self, var, lower, upper, mask=None):
  2244. assert not mask, "do not support mask in indirect_indexing assertion"
  2245. assert isinstance(var, CppCSEVariable)
  2246. assert var.dtype is not None
  2247. if not var.is_vec:
  2248. return super().indirect_assert(var, lower, upper, mask)
  2249. lower_scalar = lower
  2250. upper_scalar = upper
  2251. if lower:
  2252. lower = f"{self._get_vec_type(var.dtype)}({lower})"
  2253. if upper:
  2254. upper = f"{self._get_vec_type(var.dtype)}({upper})"
  2255. if lower and upper:
  2256. cond = f"({lower} <= {var}) & ({var} < {upper})"
  2257. cond_print = f"{lower_scalar} <= {var} < {upper_scalar}"
  2258. elif lower:
  2259. cond = f"{lower} <= {var}"
  2260. cond_print = f"{lower_scalar} <= {var}"
  2261. else:
  2262. assert upper
  2263. cond = f"{var} < {upper}"
  2264. cond_print = f"{var} < {upper_scalar}"
  2265. cond = f"({self._get_mask_type(var.dtype)}({cond})).all_masked()"
  2266. return f'{self.assert_function}({cond}, "index out of bounds: {cond_print}")'
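# For a vectorized int64 index `tmp0` bounded by [0, ks0), the method above
# expands to roughly the following check (assuming AVX2, where int64 uses
# VectorizedN<int64_t,2>; the names tmp0/ks0 are illustrative):
#   TORCH_CHECK((at::vec::VecMask<int64_t,2>(
#       (at::vec::VectorizedN<int64_t,2>(0) <= tmp0)
#       & (tmp0 < at::vec::VectorizedN<int64_t,2>(ks0)))).all_masked(),
#       "index out of bounds: 0 <= tmp0 < ks0")
# A scalar sketch of the "all lanes in range" semantics being asserted:
def _example_all_lanes_in_range(lanes, lower=0, upper=None):
    return all(lower <= v and (upper is None or v < upper) for v in lanes)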
  2267. class CppTile2DKernel(CppVecKernel):
  2268. """
A vector kernel that handles 2d tiles, with the tile size defined by `tiling_factor`, on
the inner-most loop level and one of the outer loop levels (`outer_tiling_idx`). When the data
tile is accessed contiguously from the outer loop axis, a transposition is applied to the
tile to make the access contiguous from the inner-most loop axis. Then, the same vectorization
logic from its parent `CppVecKernel` is leveraged for load/store/compute. The transposed tile loads
and stores are generated into the kernel.preloads and kernel.poststores buffers.
  2275. The loop structure looks like below:
  2276. for ...
  2277. for i_outer ...
  2278. for ...
  2279. for inner_most ...
  2280. // generated by CppTile2DKernel
  2281. float tmp0[16*16]; at::vec::transpose_mxn<...>(tmp0, in_ptr0 + ..., ...); // into kernel.preloads
  2282. float tmp1[16*16]; // into kernel.preloads
  2283. for i_inner ... { // the kernel inner loop
  2284. vectorized loads/compute/stores (e.g., load tmp0, store tmp1) // into kernel.loads/compute/stores
  2285. }
  2286. at::vec::transpose_mxn(out_ptr0 + ..., tmp1, ...) // into kernel.poststores
  2287. for inner_most ... (tail)
  2288. // generated by CppVecKernel
  2289. ...
  2290. for i_outer ... (tail)
  2291. for ...
  2292. for ...
  2293. // generated by CppKernel
  2294. ...
  2295. """
  2296. overrides = CppTile2DOverrides # type: ignore[assignment]
  2297. def __init__(self, args, num_threads, tiling_factor, tiling_indices, tiling_dtype):
  2298. super().__init__(
  2299. args, num_threads, tiling_factor, tiling_indices[1], tiling_dtype
  2300. )
  2301. self.tiling_indices = tiling_indices
  2302. def inner_itervar(self):
  2303. return sympy_index_symbol(f"{self.itervars[self.outer_idx]}_inner")
  2304. def need_vec_transpose(self, index):
  2305. outer_var = self.itervars[self.outer_idx]
  2306. inner_var = self.itervars[self.tiling_idx]
  2307. outer_stride = stride_at_vec_range(index, outer_var, self.tiling_factor)
  2308. inner_stride = stride_at_vec_range(index, inner_var, self.tiling_factor)
  2309. return (
  2310. self._load_mask is None # TODO: support transposition with mask
  2311. and outer_stride == 1
  2312. and index.has(inner_var)
  2313. and not inner_stride.has(inner_var)
  2314. and not inner_stride.has(outer_var)
  2315. )
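# Sketch of the stride test above on a transposed access pattern, using sympy
# differentiation as a stand-in for stride_at_vec_range (illustration only):
def _example_need_vec_transpose():
    import sympy
    x0, x1 = sympy.symbols("x0 x1")       # x0: outer tiled itervar, x1: inner
    index = 128 * x1 + x0                 # contiguous along the outer axis only
    outer_stride = sympy.diff(index, x0)  # -> 1
    inner_stride = sympy.diff(index, x1)  # -> 128, independent of x0 and x1
    return outer_stride == 1 and not inner_stride.free_symbols  # -> True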
  2316. def gen_transposed_tile_load_store(self, name, var, index, is_store):
  2317. # transposed tile load/store outside the kernel inner loop
  2318. dtype = V.graph.get_dtype(name)
  2319. factor = self.tiling_factor
  2320. src = f"{var} + {cexpr_index(index)}"
  2321. dst = "__place_holder__"
  2322. ld_src = f"{cexpr_index(stride_at_vec_range(index, self.itervars[self.tiling_idx], self.tiling_factor))}"
  2323. ld_dst = f"{factor}"
  2324. if is_store:
  2325. src, dst = dst, src
  2326. ld_src, ld_dst = ld_dst, ld_src
  2327. need_define = True
  2328. load_or_store = f"at::vec::transpose_mxn<{DTYPE_TO_CPP[dtype]},{factor},{factor}>({src}, {ld_src}, {dst}, {ld_dst});"
  2329. if is_store:
  2330. tile_var = self.cse.newvar()
  2331. elif load_or_store not in self.cse.cache:
  2332. tile_var = self.cse.generate(self.preloads, load_or_store, write=False)
  2333. else:
  2334. need_define = False
  2335. tile_var = self.cse.cache[load_or_store]
  2336. if need_define:
  2337. define_line = f"{DTYPE_TO_CPP[dtype]} {tile_var}[{factor}*{factor}] __attribute__ ((aligned ({factor})));"
  2338. self.preloads.writeline(define_line)
  2339. load_or_store = load_or_store.replace("__place_holder__", str(tile_var))
  2340. if is_store:
  2341. self.poststores.writeline(DeferredLine(name, load_or_store))
  2342. else:
  2343. self.preloads.writeline(load_or_store)
  2344. return tile_var
  2345. def load(self, name: str, index: sympy.Expr):
  2346. opt_ctx: OptimizationContext = get_current_node_opt_ctx()
  2347. var = self.args.input(name)
  2348. index = self.rename_indexing(index)
  2349. inner = self.inner_itervar()
  2350. if self.need_vec_transpose(index):
  2351. tile_var = self.gen_transposed_tile_load_store(
  2352. name, var, index, is_store=False
  2353. )
  2354. # vector load inside the kernel inner loop
  2355. loadbuf = f"{tile_var} + {cexpr_index(inner * self.tiling_factor)}"
  2356. dtype = V.graph.get_dtype(name)
  2357. line = self._get_vec_load_line(loadbuf, 0, dtype) # type: ignore[arg-type]
  2358. csevar = self.cse.generate(self.loads, line)
  2359. csevar.update_on_args("load", (name, index), {})
  2360. assert isinstance(csevar, CppCSEVariable)
  2361. csevar.is_vec = True
  2362. return csevar
  2363. else:
  2364. new_index = self.transform_indexing(index)
  2365. return super().load(name, new_index)
  2366. def store(self, name, index, value, mode=None):
  2367. assert "buf" in name
  2368. opt_ctx: OptimizationContext = get_current_node_opt_ctx()
  2369. var = self.args.output(name)
  2370. inner = self.inner_itervar()
  2371. index = self.rename_indexing(index)
  2372. assert mode is None
  2373. if self.need_vec_transpose(index):
  2374. tile_var = self.gen_transposed_tile_load_store(
  2375. name, var, index, is_store=True
  2376. )
  2377. # vector store inside the kernel inner loop
  2378. storebuf = f"{tile_var} + {cexpr_index(inner * self.tiling_factor)}"
  2379. if V.graph.get_dtype(name) in DTYPE_LOWP_FP:
  2380. line = f"{value}.store({storebuf}, {self.tiling_factor});"
  2381. elif V.graph.get_dtype(name) in (torch.uint8, torch.int8):
  2382. line = f"{value}.store({storebuf}, {self.tiling_factor});"
  2383. else:
  2384. line = f"{value}.store({storebuf});"
  2385. self.stores.writeline(DeferredLine(name, line))
  2386. else:
  2387. new_index = self.transform_indexing(index)
  2388. super().store(name, new_index, value, mode)
  2389. def codegen_inner_loops(self, code):
  2390. inner = self.inner_itervar()
  2391. code.writeline(
  2392. f"for (long {inner} = 0; {inner} < {self.tiling_factor}; {inner}++)"
  2393. )
  2394. def set_ranges(self, group, reduction_group):
  2395. vars = super().set_ranges(group, reduction_group)
  2396. # do vertical reduction as the tail loop
  2397. self.outer_idx, self.tiling_idx = (
  2398. self.tiling_indices
  2399. if self.tiling_indices[1] < self.reduction_depth
  2400. else reversed(self.tiling_indices)
  2401. )
  2402. return vars
  2403. def transform_indexing(self, index: sympy.Expr) -> sympy.Expr:
  2404. return self.scale_index_with_offset(
  2405. index,
  2406. itervar_idx=self.outer_idx,
  2407. offset=self.inner_itervar(),
  2408. )
  2409. class CppVecKernelChecker(CppVecKernel):
  2410. def __init__(self, args, num_threads, tiling_factor, tiling_idx=-1):
  2411. super().__init__(args, num_threads, tiling_factor, tiling_idx)
# Since this kernel is only used for checking and does not generate
# any code, we need to decrease the generated kernel count.
  2414. metrics.generated_kernel_count -= 1
  2415. # Used to record the graph wrapper code as the wrapper_code status could be
  2416. # changed during graph run.
  2417. self._orig_wrapper_code = None
  2418. self.simd_vec = True
  2419. self.fast_vec_list = []
  2420. for k, v in CppVecOverrides.__dict__.items():
  2421. if isinstance(v, staticmethod):
  2422. self.fast_vec_list.append(k)
  2423. self.exit_stack = contextlib.ExitStack()
  2424. # Cache all the load result
  2425. self.supported_dtypes: List[torch.dtype] = [
  2426. torch.float,
  2427. torch.bfloat16,
  2428. torch.float16,
  2429. torch.bool,
  2430. torch.uint8,
  2431. torch.int8,
  2432. torch.int32,
  2433. torch.int64,
  2434. ]
  2435. def disable_vec(self, msg=None):
  2436. if schedule_log.isEnabledFor(logging.DEBUG):
  2437. schedule_log.debug("Disabled vectorization: %s", msg)
  2438. self.simd_vec = False
  2439. def load(self, name: str, index: sympy.Expr):
  2440. with RecordOptimizationContext(__name__) as node_ctx:
  2441. load_dtype = V.graph.get_dtype(name)
  2442. opt_ctx: OptimizationContext = node_ctx.get_opt_ctx()
  2443. assert opt_ctx
  2444. opt_ctx.dtype = load_dtype
  2445. var = self.cse.newvar()
  2446. if len(self.itervars) == 0:
  2447. self.disable_vec("not a loop")
  2448. return var
  2449. if load_dtype not in self.supported_dtypes and (
  2450. index.has(self.itervars[self.tiling_idx])
  2451. or free_symbol_is_type(index, SymT.TMP)
  2452. ):
  2453. self.disable_vec(f"{load_dtype} not supported by load")
  2454. return var
  2455. return var
  2456. def store(self, name, index, value, mode=None):
  2457. with RecordOptimizationContext(__name__) as node_ctx:
  2458. if len(self.itervars) == 0:
  2459. self.disable_vec("not a loop")
  2460. return self.simd_vec
  2461. store_dtype = V.graph.get_dtype(name)
  2462. opt_ctx: OptimizationContext = node_ctx.get_opt_ctx()
  2463. assert opt_ctx
  2464. opt_ctx.dtype = store_dtype
  2465. if store_dtype not in self.supported_dtypes:
  2466. self.disable_vec(f"{store_dtype} not supported by store")
  2467. return self.simd_vec
  2468. assert "buf" in name
  2469. index = self.rename_indexing(index)
  2470. if mode:
  2471. self.disable_vec(f"store mode: {mode}")
  2472. return self.simd_vec
  2473. return self.simd_vec
  2474. def reduction(self, dtype, src_dtype, reduction_type, value):
  2475. if not (
2476. ((dtype == torch.float and src_dtype == torch.float)
2477. or (dtype == torch.int64 and src_dtype == torch.int64))
2478. and reduction_type in VECTORIZABLE_RTYPES
  2479. ):
  2480. self.disable_vec(
  2481. f"reduction: dtype {dtype}, src_dtype {src_dtype}, reduction_type {reduction_type}"
  2482. )
  2483. if is_welford_reduction(reduction_type):
  2484. return tuple([self.simd_vec] * 3)
  2485. return self.simd_vec
  2486. def check_bounds(
  2487. self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
  2488. ):
  2489. return self.simd_vec
  2490. def store_reduction(self, name, index, value):
  2491. return self.simd_vec
  2492. def __exit__(self, exc_type, exc_val, exc_tb):
  2493. assert self._orig_wrapper_code is not None
  2494. # Restore the wrapper_code
  2495. V.graph.wrapper_code = self._orig_wrapper_code
  2496. self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
  2497. def __enter__(self):
  2498. # Record the graph wrapper code. The wrapper_code status could be
  2499. # changed during graph run. Regarding this checker, we also need to
  2500. # run the graph but we don't expect to change any status that would
  2501. # impact the code generation. Hence, we record the graph wrapper code
  2502. # and replace it with a dummy wrapper_code and then restore to the
  2503. # original one as long as the checker is finished.
  2504. self._orig_wrapper_code = V.graph.wrapper_code
  2505. V.graph.wrapper_code = WrapperCodeGen()
  2506. parent_handler = V.MockHandler()
  2507. class VecCheckerProxy:
  2508. @staticmethod
  2509. def __getattr__(name): # type: ignore[misc]
  2510. def inner(*args, **kwargs):
  2511. if name not in self.fast_vec_list:
  2512. self.disable_vec(f"op: {name}")
  2513. parent_val = getattr(parent_handler, name)(*args, **kwargs)
  2514. return pytree.tree_map(lambda _: self.simd_vec, parent_val)
  2515. return inner
  2516. @staticmethod
  2517. def load(name: str, index: sympy.Expr):
  2518. return self.load(name, index)
  2519. @staticmethod
  2520. def store(name, index, value, mode=None):
  2521. return self.store(name, index, value, mode=mode)
  2522. @staticmethod
  2523. def reduction(dtype, src_dtype, reduction_type, value):
  2524. return self.reduction(dtype, src_dtype, reduction_type, value)
  2525. @staticmethod
  2526. def store_reduction(name, index, value):
  2527. return self.store_reduction(name, index, value)
  2528. @staticmethod
  2529. def check_bounds(
  2530. expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
  2531. ):
  2532. return self.check_bounds(expr, size, lower, upper)
  2533. @staticmethod
  2534. def constant(val, dtype):
  2535. with RecordOptimizationContext(__name__) as node_ctx:
  2536. opt_ctx: OptimizationContext = node_ctx.get_opt_ctx()
  2537. assert opt_ctx
2538. # VecKernel overrides the dtype for constants.
2539. # Vectorization only supports int32/fp32 now,
2540. # so if dtype is int64/fp64, we cast it to int32/fp32 when possible.
  2541. i32_iinfo = torch.iinfo(torch.int32)
  2542. if (
  2543. dtype == torch.int64
  2544. and val <= i32_iinfo.max
  2545. and val >= i32_iinfo.min
  2546. and all(
  2547. user.target in BIN_CMP_OPS
  2548. for user in node_ctx.current_node.users
  2549. )
  2550. ):
  2551. opt_ctx.dtype = torch.int32
2552. f32_finfo = torch.finfo(torch.float32)
2553. if dtype == torch.double:
2554. if (
2555. (val <= f32_finfo.max and val >= f32_finfo.min)
  2556. or (val == torch.inf)
  2557. or (val == -torch.inf)
  2558. ):
  2559. opt_ctx.dtype = torch.float32
  2560. if opt_ctx.dtype not in self.supported_dtypes:
  2561. self.disable_vec(f"constant dtype: {opt_ctx.dtype}")
  2562. return val
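# Illustrative sketch (added note, not part of the original logic): the
# int64 -> int32 narrowing above boils down to a bounds check plus the
# "all users are comparisons" condition. Assuming `torch` is importable:
#
#     import torch
#     i32 = torch.iinfo(torch.int32)
#     val = 7
#     fits_int32 = i32.min <= val <= i32.max   # True, so opt_ctx.dtype may be narrowed to int32
#
# A double constant is similarly narrowed to float32 when it fits
# torch.finfo(torch.float32) or is +/-inf; any remaining unsupported dtype
# disables vectorization.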
  2563. @staticmethod
  2564. def index_expr(expr, dtype):
  2565. assert len(self.ranges) == len(self.itervars)
  2566. def can_use_int32():
  2567. free_symbols = list(expr.free_symbols)
  2568. sizes = {
  2569. k: v
  2570. for k, v in zip(self.itervars, self.ranges)
  2571. if k in free_symbols
  2572. }
  2573. # Trivial case: Range empty
  2574. if any(v == 0 for v in sizes.values()):
  2575. return True
  2576. vars_ranges = {
  2577. k: ValueRanges(0, v - 1)
  2578. for k, v in sizes.items()
  2579. if not isinstance(v, sympy.Expr) or v.is_number
  2580. }
  2581. if not vars_ranges or len(vars_ranges) != len(free_symbols):
  2582. i32_iinfo = torch.iinfo(torch.int32)
  2583. return (
  2584. expr.is_number
  2585. and expr <= i32_iinfo.max
  2586. and expr >= i32_iinfo.min
  2587. )
  2588. expr_ranges = bound_sympy(expr, vars_ranges)
  2589. if math.isinf(expr_ranges.lower) or math.isinf(expr_ranges.upper): # type: ignore[arg-type]
  2590. return False
  2591. # If something takes the values 0..7, we will compare in the loop
  2592. # x < 8. As such, for the loop not to overflow in the last iteration, we want
  2593. # to check that expr_ranges.upper + 1 is representable as well
  2594. return range_expressable_in_32_bits(
  2595. ValueRanges(
  2596. int(expr_ranges.lower), int(expr_ranges.upper) + 1 # type: ignore[arg-type]
  2597. )
  2598. )
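# Illustrative sketch (added note): for itervars x0 in [0, 10) and x1 in [0, 20)
# and an index expression x0 * 20 + x1, the bound is [0, 199], and the check
# above additionally requires 199 + 1 = 200 to be representable in int32 because
# the generated loop compares against the exclusive upper bound. Roughly
# (the import path below is an assumption of this sketch):
#
#     import sympy
#     from torch.utils._sympy.value_ranges import ValueRanges, bound_sympy
#     x0, x1 = sympy.symbols("x0 x1")
#     r = bound_sympy(x0 * 20 + x1, {x0: ValueRanges(0, 9), x1: ValueRanges(0, 19)})
#     # r.lower == 0, r.upper == 199; range_expressable_in_32_bits then checks [0, 200]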
  2599. with RecordOptimizationContext(__name__) as node_ctx:
  2600. assert len(self.ranges) == len(self.itervars)
  2601. opt_ctx: OptimizationContext = node_ctx.get_opt_ctx()
  2602. assert opt_ctx
  2603. if (
  2604. dtype == torch.int64
  2605. and can_use_int32()
  2606. and all(
  2607. user.target in BIN_CMP_OPS
  2608. for user in node_ctx.current_node.users
  2609. )
  2610. ):
  2611. opt_ctx.dtype = torch.int32
  2612. else:
  2613. self.disable_vec(f"index_expr: {expr}, dtype {dtype}")
  2614. tmp_var = self.cse.newvar()
  2615. return tmp_var
  2616. @staticmethod
  2617. def indirect_indexing(index_var, size, check=True):
  2618. return sympy_index_symbol(str(index_var))
  2619. @staticmethod
  2620. def masked(mask, body, other):
  2621. body()
  2622. return self.cse.newvar()
  2623. @staticmethod
  2624. def to_dtype(x, dtype, src_dtype=None):
  2625. if dtype not in self.supported_dtypes:
  2626. self.disable_vec(f"to_dtype: {dtype}")
  2627. return x
  2628. self.exit_stack.enter_context(V.set_ops_handler(VecCheckerProxy()))
  2629. self.exit_stack.enter_context(V.set_kernel_handler(self))
  2630. return self
  2631. class CppKernelProxy(CppKernel):
  2632. def __init__(self, kernel_group):
  2633. super().__init__(kernel_group.args, kernel_group.ws.num_threads)
  2634. self.kernel_group = kernel_group
  2635. self.loop_nest = None
  2636. self.call_ranges = None
  2637. self.picked_vec_isa: codecache.VecISA = codecache.pick_vec_isa()
  2638. def data_type_propagation(self, nodes):
  2639. for _node in nodes:
  2640. assert isinstance(_node, SchedulerNode)
  2641. DataTypePropagation.propagate_scheduler_node(_node)
  2642. # Check if all the nodes of a given fx graph can support BF16/FP16
  2643. def is_lowp_fp_scheduler(self, scheduler_node: SchedulerNode):
  2644. if not isinstance(scheduler_node._body, ir.LoopBody):
  2645. return True
  2646. _lowp_fp_type: Optional[torch.dtype] = None
2647. # Propagate the dtype to check whether all the fx nodes are bf16/fp16
  2648. DataTypePropagation.propagate_scheduler_node(scheduler_node)
  2649. sub_blocks = [scheduler_node._body.root_block] + list(
  2650. scheduler_node._body.subblocks.values()
  2651. )
  2652. for sub_block in sub_blocks:
  2653. for _node in sub_block.graph.nodes:
2654. # TODO(Eikan): Regarding get_index and index_expr, we should deduce
2655. # the data type as well.
  2656. if _node.op == "placeholder" or _node.target in (
  2657. "get_index",
  2658. "index_expr",
  2659. ):
  2660. continue
  2661. # Fast path if all operations can support bf16/fp16 without converting to fp32
  2662. if _node.target not in [
  2663. "load",
  2664. "store",
  2665. "abs",
  2666. "neg",
  2667. "output",
  2668. ]:
  2669. return False
  2670. if hasattr(_node, "meta") and _node.meta:
  2671. assert OptimizationContext.key in _node.meta
  2672. opt_ctx: OptimizationContext = _node.meta[OptimizationContext.key]
  2673. if not opt_ctx.dtype or opt_ctx.dtype not in DTYPE_LOWP_FP:
  2674. return False
  2675. if _lowp_fp_type:
  2676. assert (
  2677. _lowp_fp_type == opt_ctx.dtype
2678. ), "scheduler node does not support mixing bf16 and fp16"
  2679. else:
  2680. _lowp_fp_type = opt_ctx.dtype
  2681. else:
  2682. return False
  2683. scheduler_node._lowp_fp_type = _lowp_fp_type # type: ignore[attr-defined]
  2684. return True
  2685. def legalize_lowp_fp_dtype(self, nodes):
  2686. def add_to_dtype(sub_graph: torch.fx.Graph):
  2687. def is_lowp_fp_load(node: torch.fx.Node):
  2688. if node.target not in ["load"]:
  2689. return False
  2690. assert len(node.args) == 3
  2691. load_dtype = V.graph.get_dtype(node.args[1]) # type: ignore[arg-type]
  2692. return load_dtype in DTYPE_LOWP_FP
  2693. def is_lowp_fp_store(node: torch.fx.Node):
  2694. if node.target != "store":
  2695. return False
  2696. _, store_var, _, _, _ = node.args
  2697. store_dtype = V.graph.get_dtype(store_var) # type: ignore[arg-type]
  2698. return store_dtype in DTYPE_LOWP_FP
  2699. sub_graph_nodes = list(sub_graph.nodes)
  2700. to_lowp_fp_legalized_nodes = []
  2701. for _node in sub_graph_nodes:
  2702. if is_lowp_fp_load(_node):
  2703. # No need to promote to float if all users are direct stores
  2704. if all(user.target == "store" for user in _node.users):
  2705. continue
  2706. ops = _node.args[0]
  2707. with sub_graph.inserting_after(_node):
  2708. to_type_node = sub_graph.call_method(
  2709. "to_dtype", args=(ops, _node, torch.float)
  2710. )
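# Note (added comment): replace_all_uses_with below would also rewrite the use
# of _node inside to_type_node itself, so the original args are saved here and
# restored right after the replacement.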
  2711. to_type_node_args = to_type_node.args
  2712. _node.replace_all_uses_with(to_type_node)
  2713. to_type_node.args = to_type_node_args
  2714. metrics.cpp_to_dtype_count += 1
  2715. elif is_lowp_fp_store(_node):
  2716. ops, name, _, value_var, _ = _node.args
2717. # No need to promote to float if the value comes from a load whose users are all direct stores
  2718. if value_var.target == "load" and all(
  2719. user.target == "store" for user in value_var.users
  2720. ):
  2721. continue
  2722. dtype = V.graph.get_dtype(name)
  2723. with sub_graph.inserting_before(_node):
  2724. to_type_node = sub_graph.call_method(
  2725. "to_dtype", args=(ops, value_var, dtype)
  2726. )
  2727. _node.replace_input_with(value_var, to_type_node)
  2728. metrics.cpp_to_dtype_count += 1
  2729. elif _node.target == "reduction":
  2730. (
  2731. ops,
  2732. dtype,
  2733. src_dtype,
  2734. reduction_type,
  2735. value,
  2736. ) = _node.args
  2737. if src_dtype in DTYPE_LOWP_FP:
2738. # Since we always convert the load/store value to float if the tensor is bfloat16/float16,
2739. # the reduction should never operate on bfloat16/float16 values. Hence, we update
2740. # the bfloat16/float16 reduction by
2741. # 1) updating the src_dtype to float
2742. # and 2) updating the dtype to float if it is bfloat16/float16.
  2743. assert dtype in [
  2744. torch.float,
  2745. torch.bfloat16,
  2746. torch.float16,
  2747. torch.int64,
  2748. ]
  2749. _node.args = (
  2750. ops,
  2751. torch.float if dtype in DTYPE_LOWP_FP else dtype,
  2752. torch.float,
  2753. reduction_type,
  2754. value,
  2755. )
  2756. elif _node.target == "to_dtype" and _node.args[-1] in DTYPE_LOWP_FP:
  2757. (ops, x, _) = _node.args
  2758. # The legalization always loads the BF16/FP16 tensor as FP32 for computation
  2759. # and converts back to BF16/FP16 after the computation.
  2760. # Hence, there should be no computation w/ BF16/FP16.
  2761. # Therefore, we update the to_dtype by replacing the bf16/fp16 dtype with fp32.
  2762. # Save the legalized to_dtype node for the elimination(eliminate_to_dtype step):
  2763. # 1) Eliminate the redundant to_dtype node if we have a pattern as follows:
  2764. # graph():
  2765. # %lowp_fp_legalized = call_method[target=to_dtype](args = (%ops, %input, torch.float))
  2766. # %to_dtype2 = call_method[target=to_dtype](args = (%ops, %lowp_fp_legalized, torch.bfloat16/float16))
  2767. # Regarding the first to_dtype, it is redundant because
2768. # the second to_dtype also converts to torch.bfloat16/torch.float16.
2769. # Hence, we remove the first to_dtype.
  2770. to_lowp_fp_legalized_nodes.append(_node)
  2771. _node.args = (ops, x, torch.float)
  2772. else:
  2773. pass
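# Illustrative sketch (added note): for a bf16 buffer, add_to_dtype rewrites a
# loop body like
#
#     %x   = load(buf0)                       # bf16
#     %y   = relu(%x)
#     store(buf1, %y)                         # bf16
#
# into
#
#     %x    = load(buf0)                      # bf16
#     %x_f  = to_dtype(%x, torch.float)
#     %y    = relu(%x_f)
#     %y_lp = to_dtype(%y, torch.bfloat16)
#     store(buf1, %y_lp)
#
# Pure load->store copies are left untouched, and redundant conversions are
# cleaned up afterwards by eliminate_to_dtype.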
  2774. def eliminate_to_dtype(sub_graph: torch.fx.Graph):
  2775. def _eliminate_duplicate_to_node(sub_graph: torch.fx.Graph):
  2776. # Eliminate the redundant to_dtype node. Let's consider a pattern as follows:
  2777. # graph():
  2778. # %to_dtype1 = call_method[target=to_dtype](args = (%ops, %input, torch.float), kwargs = {})
  2779. # %to_dtype2 = call_method[target=to_dtype](args = (%ops, %to_dtype1, torch.float), kwargs = {})
2780. # Regarding the first to_dtype, it is redundant because the second to_dtype also converts to
2781. # torch.float. Hence, we remove the first to_dtype.
  2782. def _used_by_to(to_node: torch.fx.Node):
  2783. return all(usr.target == "to_dtype" for usr in to_node.users)
  2784. all_to_nodes = [
  2785. node for node in sub_graph.nodes if node.target == "to_dtype"
  2786. ]
  2787. all_to_nodes_and_users = [
  2788. {node: node.users} for node in all_to_nodes if _used_by_to(node)
  2789. ]
  2790. for node_users in all_to_nodes_and_users:
  2791. for node, users in node_users.items():
  2792. if node in sub_graph.nodes and (
  2793. all(usr.args[-1] == node.args[-1] for usr in users)
  2794. or (
  2795. node in to_lowp_fp_legalized_nodes
  2796. and all(
  2797. usr.args[-1] in DTYPE_LOWP_FP for usr in users
  2798. )
  2799. )
  2800. ):
  2801. val_node = node.all_input_nodes[-1]
  2802. node.replace_all_uses_with(val_node)
  2803. sub_graph.erase_node(node)
2804. # In debug mode, the graph of LoopBody will attach a new GraphModule as
2805. # owning_module for debugging, while release mode will not. lint() checks
2806. # whether the graph has an owning_module to decide if it needs to validate
2807. # call_module targets. LoopBody might contain get_index as a module call, but
2808. # it is just a function, so the graph cannot pass the lint check in debug mode.
2809. # Since lint() skips that check when owning_module is None, we only run lint in
2810. # that case. Eventually, we should call get_index via call_function, not call_module.
  2811. if sub_graph.owning_module is None:
  2812. sub_graph.lint()
  2813. _eliminate_duplicate_to_node(sub_graph)
  2814. eliminate_to_dtype(sub_graph)
  2815. def _legalize_lowp_fp(loop_body: ir.LoopBody):
  2816. sub_blocks = [loop_body.root_block] + list(loop_body.subblocks.values())
  2817. for sub_block in sub_blocks:
  2818. add_to_dtype(sub_block.graph)
  2819. if all(
  2820. isinstance(_node, SchedulerNode) and self.is_lowp_fp_scheduler(_node)
  2821. for _node in nodes
  2822. ):
  2823. # Mark the load node to load bf16/fp16
  2824. for _node in nodes:
  2825. sub_blocks = [_node._body.root_block] + list(
  2826. _node._body.subblocks.values()
  2827. )
  2828. for sub_block in sub_blocks:
  2829. for fx_node in sub_block.graph.nodes:
  2830. if fx_node.target in ["load", "store"]:
  2831. assert fx_node.meta
  2832. assert OptimizationContext.key in fx_node.meta
  2833. opt_ctx: OptimizationContext = fx_node.meta[
  2834. OptimizationContext.key
  2835. ]
  2836. assert opt_ctx.dtype in DTYPE_LOWP_FP
  2837. # Bypass the legalization as the kernel can run with bf16/fp16 directly
  2838. return
  2839. for _node in nodes:
  2840. assert isinstance(_node, SchedulerNode)
  2841. assert isinstance(_node._body, ir.LoopBody)
  2842. node: SchedulerNode = _node
  2843. def is_memory_copy_scheduler_node(node: SchedulerNode):
  2844. op_counts = node.read_writes.op_counts
  2845. return (
  2846. len(op_counts) == 2 and "load" in op_counts and "store" in op_counts
  2847. )
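# Illustrative sketch (added note): a pure memory-copy body only reads one
# buffer and writes another, e.g. op_counts == {"load": 1, "store": 1}, so no
# fp32 round-trip is needed and low-precision legalization is skipped for it.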
  2848. should_legalize = not is_memory_copy_scheduler_node(node)
  2849. if should_legalize:
  2850. body: ir.LoopBody = node._body
  2851. _legalize_lowp_fp(body)
  2852. def codegen_functions(self, fn_list, var_sizes_list, vec_dtype=torch.float):
  2853. # TODO(jgong5): remove vec_dtype arg with alternative tiling factors for various dtypes
  2854. assert len(fn_list) == len(var_sizes_list)
  2855. kernel_group = self.kernel_group
  2856. group, reduction_group = max(var_sizes_list, key=lambda sizes: len(sizes[1]))
  2857. self.set_ranges(group, reduction_group)
  2858. def codegen_kernel(cls, *args):
  2859. with kernel_group.new_kernel(cls, *args) as kernel:
  2860. # Ugly hack to maintain the metrics kernel count since
  2861. # we only count in CppKernelProxy, not those contained in it
  2862. metrics.generated_kernel_count -= 1
  2863. run(kernel)
  2864. return kernel
  2865. def run(kernel):
  2866. vars, reduction_vars = kernel.set_ranges(group, reduction_group)
  2867. in_suffix = False
  2868. for fn, var_sizes in zip(fn_list, var_sizes_list):
  2869. if var_sizes in [
  2870. (group, reduction_group),
  2871. (tuple(itertools.chain(group, reduction_group)), ()),
  2872. ]:
  2873. assert not in_suffix
  2874. fn(vars, reduction_vars)
  2875. else:
  2876. in_suffix = True
  2877. assert var_sizes == (
  2878. group,
  2879. (),
  2880. ), f"unexpected group: {var_sizes} != {group}, {reduction_group}"
  2881. # we can fuse in some extra pointwise into the suffix
  2882. with kernel.write_to_suffix():
  2883. fn(vars, ())
  2884. scalar_kernel = codegen_kernel(CppKernel)
  2885. V.graph.removed_buffers |= scalar_kernel.removed_buffers
  2886. V.graph.inplaced_to_remove |= scalar_kernel.inplaced_to_remove
  2887. self.loop_nest = LoopNestWithSplit.build(scalar_kernel)
  2888. if not self.picked_vec_isa:
  2889. return
  2890. def select_tiling_indices(tiling_factor):
  2891. all_index = []
  2892. for fn, var_sizes in zip(fn_list, var_sizes_list):
  2893. rw = dependencies.extract_read_writes(fn, *var_sizes)
  2894. all_index += [dep.index for dep in itertools.chain(rw.reads, rw.writes)]
  2895. contig_vars = set()
  2896. contig_vars_list = []
  2897. non_contig_stride_const = set()
  2898. non_contig_stride_other = set()
  2899. for index in all_index:
  2900. for var in index.free_symbols:
  2901. if not re.search(r"^d\d+$", var.name):
  2902. continue
  2903. stride = stride_at_vec_range(index, var, tiling_factor)
  2904. if stride == 0:
  2905. continue
  2906. elif stride == 1:
  2907. contig_vars.add(int(var.name[1:]))
  2908. contig_vars_list.append(int(var.name[1:]))
  2909. elif all(symbol_is_type(s, SymT.SIZE) for s in stride.free_symbols):
  2910. non_contig_stride_const.add(int(var.name[1:]))
  2911. else:
  2912. non_contig_stride_other.add(int(var.name[1:]))
  2913. contig_only = (
  2914. contig_vars - non_contig_stride_const - non_contig_stride_other
  2915. )
  2916. if len(contig_vars) == 0:
  2917. # no contiguous vars
  2918. return [len(self.itervars) - 1]
  2919. if contig_only:
  2920. return sorted(contig_only)[-1:]
  2921. contig_and_const_stride = (
  2922. contig_vars & non_contig_stride_const
  2923. ) - non_contig_stride_other
  2924. contig_vars_sorted = sorted(contig_vars)
  2925. if (
  2926. len(contig_vars_sorted) == 2
  2927. and contig_vars_sorted[-1] in contig_and_const_stride
  2928. and contig_vars_sorted[-1] == len(self.itervars) - 1
  2929. ):
  2930. return contig_vars_sorted
  2931. return sorted(contig_vars_sorted, key=contig_vars_list.count)[-1:]
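# Illustrative sketch (added note): for an access like d0 * 128 + d1 with
# tiling_factor 16, d1 has stride 1 (contiguous) and d0 has constant stride 128,
# so the heuristic above picks [1] (the itervar index of d1) as the tiling
# index. When no itervar is contiguous in any access, it falls back to the
# innermost loop, [len(self.itervars) - 1].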
  2932. def select_tiling(dtype: torch.dtype = torch.float):
  2933. # TODO(jgong5): support alternative tiling factors and data types
  2934. tiling_factor = self.picked_vec_isa.nelements(dtype=dtype)
  2935. tiling_indices = select_tiling_indices(tiling_factor)
  2936. if tiling_indices:
  2937. could_vec = True
  2938. for tiling_indice in tiling_indices:
  2939. with CppVecKernelChecker(
  2940. deepcopy(self.kernel_group.args),
  2941. parallel_num_threads(),
  2942. tiling_factor,
  2943. tiling_indice,
  2944. ) as vec_checker:
  2945. run(vec_checker)
  2946. could_vec = could_vec and vec_checker.simd_vec
  2947. if not could_vec:
  2948. break
  2949. if could_vec:
  2950. if len(tiling_indices) == 1:
  2951. return [tiling_factor], tiling_indices
  2952. if len(tiling_indices) == 2:
  2953. return [tiling_factor, tiling_factor], tiling_indices
  2954. return [], []
2955. # Kernels share the same global contexts, e.g. V.graph.wrapper_code and V.kernel.args.
2956. # The generated scalar kernel has already updated these global contexts, so the other kernels
2957. # should not do so again to avoid context conflicts. For now, we only control
2958. # config.inplace_buffers. In the future, we could manage more contexts.
  2959. with torch._inductor.config.patch(inplace_buffers=False):
  2960. tiling_factors, tiling_indices = select_tiling(vec_dtype)
  2961. assert len(tiling_factors) == len(tiling_indices)
  2962. if len(tiling_indices) == 1:
  2963. vec_kernel = codegen_kernel(
  2964. CppVecKernel, tiling_factors[0], tiling_indices[0], vec_dtype
  2965. )
  2966. metrics.generated_cpp_vec_kernel_count += 1
  2967. main_loop, tail_loop = self.loop_nest.split_with_tiling(
  2968. tiling_indices[0], factor=tiling_factors[0]
  2969. )
  2970. main_loop.set_kernel(vec_kernel)
  2971. tail_loop.set_kernel(scalar_kernel)
  2972. main_loop.simd_vec = True
  2973. tail_loop.simd_omp = True
2974. # We split the loop into two parts by nelements: the main loop and the tail loop.
2975. # The main loop is straightforwardly vectorized with nelements. The tail loop
2976. # can still be vectorized, just with a smaller width. For example,
2977. # if nelements is 8 (256 bits), the tail loop can still be vectorized
2978. # with 4 (128 bits).
  2979. tail_loop.simd_nelements = tiling_factors[0] // 2
  2980. elif len(tiling_indices) == 2:
  2981. assert (
  2982. tiling_indices[1] == len(self.itervars) - 1
  2983. and tiling_factors[0] == tiling_factors[1]
  2984. )
  2985. tile2d_kernel = codegen_kernel(
  2986. CppTile2DKernel, tiling_factors[0], tiling_indices, vec_dtype
  2987. )
  2988. vec_kernel = codegen_kernel(
  2989. CppVecKernel, tiling_factors[0], tiling_indices[0], vec_dtype
  2990. )
  2991. metrics.generated_cpp_vec_kernel_count += 2
  2992. outer_main_loop, outer_tail_loop = self.loop_nest.split_with_tiling(
  2993. tiling_indices[0], factor=tiling_factors[0]
  2994. )
  2995. outer_tail_loop.set_kernel(scalar_kernel)
  2996. (
  2997. inner_main_loop,
  2998. inner_tail_loop,
  2999. ) = outer_main_loop.split_with_tiling(
  3000. tiling_indices[1] - tiling_indices[0], factor=tiling_factors[0]
  3001. )
  3002. inner_main_loop.set_kernel(tile2d_kernel)
  3003. inner_tail_loop.set_kernel(vec_kernel)
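# Illustrative sketch (added note): with tiling_indices == [i, j] and factor F,
# the resulting loop structure is roughly
#
#     outer main loop (tiled by F at depth i)
#         inner main loop (tiled by F at depth j)  -> CppTile2DKernel
#         inner tail loop                          -> CppVecKernel
#     outer tail loop                              -> scalar CppKernel
#
# so the 2D-tiled kernel handles full FxF tiles while the vector and scalar
# kernels cover the remainders.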
  3004. def codegen_loop_bodies(self, loop_bodies, var_sizes_list):
  3005. # TODO(jgong5): support lowp legalization
  3006. for body in loop_bodies:
  3007. DataTypePropagation.propagate_loopbody(body)
  3008. self.codegen_functions(loop_bodies, var_sizes_list)
  3009. def codegen_nodes(self, nodes: List[SchedulerNode]):
  3010. # Legalize BF16 node by adding to_dtype explicitly
  3011. self.legalize_lowp_fp_dtype(nodes)
  3012. self.data_type_propagation(nodes)
  3013. assert len(nodes) >= 1
  3014. first_node = nodes[0]
  3015. vec_dtype = (
  3016. first_node._lowp_fp_type # type: ignore[attr-defined]
  3017. if all(
  3018. hasattr(_node, "_lowp_fp_type")
  3019. and _node._lowp_fp_type == first_node._lowp_fp_type # type: ignore[attr-defined]
  3020. for _node in nodes
  3021. )
  3022. else torch.float
  3023. )
  3024. def fn(node, *index_vars):
  3025. node.decide_inplace_update()
  3026. node.mark_run()
  3027. if isinstance(V.kernel, NullKernelHandler):
  3028. return node._body(*index_vars)
  3029. else:
  3030. return node.codegen(index_vars)
  3031. fn_list = [functools.partial(fn, node) for node in nodes]
  3032. var_sizes_list = [node.group[1] for node in nodes]
  3033. self.codegen_functions(fn_list, var_sizes_list, vec_dtype)
  3034. def codegen_loops(self, code, worksharing):
  3035. self.codegen_loops_impl(self.loop_nest, code, worksharing)
  3036. class OuterLoopFusedKernel(CppKernel):
  3037. def __init__(self, kernel_group):
  3038. super().__init__(kernel_group.args, kernel_group.ws.num_threads)
  3039. self.inner: List[LoopLevel] = []
  3040. def decide_parallel_depth(self, max_parallel_depth, threads) -> int:
  3041. kernels_parallel_depth = []
  3042. nested_kernels: List[List[CppKernel]] = [
  3043. loop.get_kernels() for loop in self.inner
  3044. ]
  3045. for kernels in nested_kernels:
  3046. # For any ScalarKernel, VecKernel, or Tile2DKernel,
  3047. # they should all have the same call_ranges
  3048. call_ranges = kernels[0].call_ranges
  3049. assert call_ranges is not None
  3050. assert all(kernel.call_ranges == call_ranges for kernel in kernels)
  3051. kernels_parallel_depth.append(
  3052. kernels[0].decide_parallel_depth(len(call_ranges), threads)
  3053. )
  3054. return min(
  3055. max_parallel_depth,
  3056. max(kernels_parallel_depth),
  3057. )
  3058. class ReasonFusedNodes(Enum):
  3059. SAME_VARS_REDUCE = "same_vars_reduce"
  3060. COMPATIBLE_REDUCTION = "compatible_reduction"
  3061. COMPATIBLE_RANGES_NO_REDUCTION = "compatible_ranges_no_reduction"
  3062. class CppScheduling(BaseScheduling):
  3063. # ctypes limits the number of args to 1024, refer to:
  3064. # https://github.com/python/cpython/commit/a285af7e626d1b81cf09f8b2bf7656f100bc1237
  3065. # We set a conservative threshold here.
  3066. MAX_FUSED_KERNEL_ARGS_NUM = 500
  3067. def __init__(self, scheduler):
  3068. self.scheduler = scheduler
  3069. self.reset_kernel_group()
  3070. self._ready_to_flush = False
  3071. def _set_flush_status(self, status: bool):
  3072. self._ready_to_flush = status
  3073. def group_fn(self, sizes):
  3074. return tuple(tuple(map(V.graph.sizevars.simplify, s)) for s in sizes)
  3075. def reset_kernel_group(self):
  3076. from .cpp_wrapper_cpu import CppWrapperCpu
  3077. self.kernel_group: Union[CppWrapperKernelGroup, KernelGroup]
  3078. if isinstance(V.graph.wrapper_code, CppWrapperCpu):
  3079. self.kernel_group = CppWrapperKernelGroup()
  3080. else:
  3081. self.kernel_group = KernelGroup()
  3082. def fuse(self, node1, node2):
  3083. if node1.is_foreach() or node2.is_foreach():
  3084. return ForeachKernelSchedulerNode.fuse(node1, node2)
  3085. elif node1.is_template():
  3086. assert not node2.is_template()
  3087. return FusedSchedulerNode.fuse(node1, node2)
  3088. else:
  3089. if (
  3090. self._why_fuse_nodes(node1, node2)
  3091. == ReasonFusedNodes.COMPATIBLE_RANGES_NO_REDUCTION
  3092. ):
  3093. assert isinstance(node1, (SchedulerNode, FusedSchedulerNode))
  3094. assert isinstance(node2, (SchedulerNode, FusedSchedulerNode))
  3095. _, (vars1, reduce1) = node1.group
  3096. _, (vars2, reduce2) = node2.group
  3097. assert reduce1 == () and reduce2 == (), (reduce1, reduce2)
  3098. def get_indexing_ranges_exprs(node):
  3099. if isinstance(node, FusedSchedulerNode):
  3100. assert len(node.snodes) > 0, node.snodes
  3101. var_ranges = None
  3102. indexing_exprs = set()
  3103. for snode in node.snodes:
  3104. v, exprs = get_indexing_ranges_exprs(snode)
  3105. if var_ranges is None:
  3106. var_ranges = v
  3107. assert var_ranges == v, (var_ranges, v, node.snodes)
  3108. indexing_exprs.update(exprs)
  3109. return var_ranges, list(indexing_exprs)
  3110. else:
  3111. assert isinstance(node, SchedulerNode)
  3112. comp_buffer = node.node
  3113. assert isinstance(comp_buffer, ir.ComputedBuffer)
  3114. _, body, _ = comp_buffer.get_default_sizes_body()
  3115. return body.var_ranges, list(body.indexing_exprs.values())
  3116. node_to_recomp = node1 if len(vars1) < len(vars2) else node2
  3117. assert isinstance(node_to_recomp, SchedulerNode)
  3118. ref_node = node2 if len(vars1) < len(vars2) else node1
  3119. extra_indexing_constraints = get_indexing_ranges_exprs(ref_node)
  3120. node_to_recomp.recompute_size_and_body(
  3121. extra_indexing_constraints=extra_indexing_constraints
  3122. )
  3123. _, (vars1, _) = node1.group
  3124. _, (vars2, _) = node2.group
  3125. assert vars1 == vars2, (vars1, vars2)
  3126. return FusedSchedulerNode.fuse(node1, node2)
  3127. elif self.can_fuse_vertical_outer_loop(node1, node2):
  3128. return OuterLoopFusedSchedulerNode.fuse(
  3129. node1, node2, self._get_outer_loop_fusion_depth(node1, node2)
  3130. )
  3131. else:
  3132. return FusedSchedulerNode.fuse(node1, node2)
  3133. def _why_fuse_nodes(self, node1, node2) -> Optional[ReasonFusedNodes]:
  3134. _, (vars1, reduce1) = node1.group
  3135. _, (vars2, reduce2) = node2.group
  3136. if vars1 == vars2 and reduce1 == reduce2:
  3137. return ReasonFusedNodes.SAME_VARS_REDUCE
  3138. if reduce1 == () and vars1 == vars2 + reduce2:
  3139. return ReasonFusedNodes.COMPATIBLE_REDUCTION
  3140. if self._can_fuse_nodes_with_compatible_ranges(node1, node2):
  3141. return ReasonFusedNodes.COMPATIBLE_RANGES_NO_REDUCTION
  3142. # TODO(jansel): allow fusion pointwise (vars1, ()) suffix?
  3143. return None
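# Illustrative sketch (added note): with node groups expressed as (vars, reduce)
# tuples, e.g.
#     node1.group[1] == ((s0, s1), ())    # pointwise over s0 * s1 elements
#     node2.group[1] == ((s0,), (s1,))    # reduction over s1
# vars1 == vars2 + reduce2 holds, so the pair is reported as
# ReasonFusedNodes.COMPATIBLE_REDUCTION.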
  3144. def _can_fuse_nodes_with_compatible_ranges(self, node1, node2):
  3145. # Here we try to fuse SchedulerNode/FusedSchedulerNode with compatible ranges
  3146. # e.g. (s0, s1, s2) and (s0 * s1 * s2)
  3147. _, (vars1, reduce1) = node1.group
  3148. _, (vars2, reduce2) = node2.group
  3149. c1 = reduce1 == () and reduce2 == ()
  3150. c2 = math.prod(vars1) == math.prod(vars2)
  3151. c3 = len(vars1) == 1 or len(vars2) == 1
  3152. if not (c1 and c2 and c3):
  3153. return False
  3154. node_to_recomp = node1 if len(vars1) < len(vars2) else node2
  3155. ref_node = node2 if len(vars1) < len(vars2) else node1
  3156. # We can not recompute sizes and body for nodes other than SchedulerNode
  3157. # TODO: we can extend fusion support with compatible ranges for FusedSchedulerNode
  3158. if isinstance(node_to_recomp, FusedSchedulerNode):
  3159. return False
3160. # It may happen that node1 and node2 have a compatible number of elements
3161. # but different original ranges, for example:
  3162. # {d0: s0, d1: s1, d2: s2} vs {d0: s0*s1*s2}
  3163. # See https://github.com/pytorch/pytorch/pull/120077/files#r1500427848 for more details
  3164. # TODO: we can fix if it allows us to CSE at least one of the variables
  3165. assert isinstance(node_to_recomp, SchedulerNode)
  3166. if isinstance(node_to_recomp.node, ir.TemplateBuffer):
  3167. return False
  3168. assert isinstance(node_to_recomp.node, ir.ComputedBuffer)
  3169. # node.data.get_size() is a cheaper version of node.get_read_writes().var_ranges
  3170. # but without variable name
  3171. ranges2 = node_to_recomp.node.data.get_size()
  3172. ranges1 = None
  3173. if isinstance(ref_node, FusedSchedulerNode):
  3174. ranges_set = set()
  3175. for snode in ref_node.snodes:
  3176. if isinstance(snode.node, ir.TemplateBuffer):
  3177. break
  3178. assert isinstance(snode.node, ir.ComputedBuffer)
  3179. ranges_set.add(tuple(snode.node.data.get_size()))
  3180. if len(ranges_set) != 1:
  3181. return False
  3182. ranges1 = list(next(iter(ranges_set)))
  3183. else:
  3184. assert isinstance(ref_node, SchedulerNode)
  3185. assert isinstance(ref_node.node, ir.ComputedBuffer)
  3186. ranges1 = ref_node.node.data.get_size()
  3187. if ranges1 != ranges2:
  3188. return False
  3189. return True
  3190. def _can_fuse_horizontal_impl(self, node1, node2):
  3191. assert isinstance(node1, (FusedSchedulerNode, SchedulerNode))
  3192. assert isinstance(node2, (FusedSchedulerNode, SchedulerNode))
  3193. if any(
  3194. isinstance(node, OuterLoopFusedSchedulerNode) for node in (node1, node2)
  3195. ):
  3196. return False
  3197. return self._why_fuse_nodes(node1, node2) is not None
  3198. def can_fuse_horizontal(self, node1, node2):
  3199. if node1.is_template() or node2.is_template():
  3200. return False
  3201. if (
  3202. len(node1.get_nodes()) + len(node2.get_nodes())
  3203. > config.cpp.max_horizontal_fusion_size
  3204. ):
  3205. return False
  3206. return self._can_fuse_horizontal_impl(node1, node2)
  3207. def _get_outer_loop_fusion_depth(self, node1, node2):
  3208. DISABLE_OUTER_LOOP_FUSION = 0
  3209. if not all(
  3210. type(node)
  3211. in (OuterLoopFusedSchedulerNode, FusedSchedulerNode, SchedulerNode)
  3212. for node in (node1, node2)
  3213. ):
  3214. return DISABLE_OUTER_LOOP_FUSION
  3215. _node1 = (
  3216. node1.get_outer_nodes()[-1]
  3217. if isinstance(node1, OuterLoopFusedSchedulerNode)
  3218. else node1
  3219. )
  3220. assert isinstance(_node1, (FusedSchedulerNode, SchedulerNode))
  3221. _node2 = (
  3222. node2.get_outer_nodes()[0]
  3223. if isinstance(node2, OuterLoopFusedSchedulerNode)
  3224. else node2
  3225. )
  3226. assert isinstance(_node2, (FusedSchedulerNode, SchedulerNode))
  3227. _, (vars1, reduce1) = _node1.group
  3228. _, (vars2, reduce2) = _node2.group
  3229. if vars1 == () and vars2 == () and reduce1 != () and reduce2 != ():
  3230. # Reduction only
  3231. return DISABLE_OUTER_LOOP_FUSION
  3232. if all(type(node) is OuterLoopFusedSchedulerNode for node in (node1, node2)):
  3233. return (
  3234. node1.outer_loop_fusion_depth
  3235. if node1.outer_loop_fusion_depth == node2.outer_loop_fusion_depth
  3236. else DISABLE_OUTER_LOOP_FUSION
  3237. )
  3238. outer_loop_fusion_depth = min(len(vars1), len(vars2))
  3239. if (
  3240. outer_loop_fusion_depth >= 1
  3241. and vars1[:outer_loop_fusion_depth] == vars2[:outer_loop_fusion_depth]
  3242. ):
  3243. if any(
  3244. type(node) is OuterLoopFusedSchedulerNode for node in (node1, node2)
  3245. ):
  3246. _compare_node = (
  3247. node1 if type(node1) is OuterLoopFusedSchedulerNode else node2
  3248. )
  3249. if _compare_node.outer_loop_fusion_depth == outer_loop_fusion_depth:
  3250. # Same outer loop fusion depth as prev nodes in OuterLoopFusedSchedulerNode
  3251. return outer_loop_fusion_depth
  3252. else:
  3253. return DISABLE_OUTER_LOOP_FUSION
  3254. else:
  3255. # First 2 nodes to generate OuterLoopFusedSchedulerNode
  3256. return outer_loop_fusion_depth
  3257. return DISABLE_OUTER_LOOP_FUSION
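# Illustrative sketch (added note): with
#     _node1.group[1] == ((s0, s1), (s2,))
#     _node2.group[1] == ((s0, s1, s2), ())
# the shared prefix (s0, s1) gives outer_loop_fusion_depth == 2, i.e. the two
# kernels can share the two outermost loops while keeping their own inner loops.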
  3258. def can_fuse_vertical_outer_loop(self, node1, node2):
  3259. return (
  3260. not node1.is_template()
  3261. and not node2.is_template()
  3262. and node1.get_names() & node2.ancestors
  3263. and not (
  3264. self._can_fuse_horizontal_impl(node1, node2)
  3265. and not node1.is_reduction()
  3266. )
  3267. and self._get_outer_loop_fusion_depth(node1, node2) >= 1
  3268. )
  3269. def get_fusion_pair_priority(self, node1, node2):
  3270. if self.can_fuse_vertical_outer_loop(node1, node2):
  3271. # Outer loop fusion with lower priority
  3272. return 1
  3273. else:
  3274. return 0
  3275. def can_fuse_vertical(self, node1, node2):
  3276. if node2.is_template():
  3277. # TODO(jgong5): support pre-op fusion with template
  3278. return False
  3279. if node1.is_template():
  3280. return not node2.is_reduction()
  3281. return (
  3282. self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction()
  3283. ) or self.can_fuse_vertical_outer_loop(node1, node2)
  3284. def codegen_node(
  3285. self,
  3286. node: Union[OuterLoopFusedSchedulerNode, FusedSchedulerNode, SchedulerNode],
  3287. ):
  3288. """
3289. Turn a set of pre-fused nodes into a C++ kernel.
  3290. """
  3291. kernel_group = self.kernel_group
  3292. if isinstance(node, OuterLoopFusedSchedulerNode):
  3293. cpp_kernel_proxy_list: List[CppKernelProxy] = []
  3294. nodes_list: List[List[SchedulerNode]] = []
  3295. for _node in node.get_outer_nodes():
  3296. assert isinstance(_node, (FusedSchedulerNode, SchedulerNode))
  3297. _nodes: List[SchedulerNode] = _node.get_nodes() # type: ignore[assignment]
  3298. cpp_kernel_proxy = CppKernelProxy(kernel_group)
  3299. cpp_kernel_proxy.codegen_nodes(_nodes)
  3300. cpp_kernel_proxy_list.append(cpp_kernel_proxy)
  3301. nodes_list.append(_nodes)
3302. # Note that, in the future, when every kernel can be vectorized,
3303. # the function select_tiling will be much simpler, and we'll be able to lift
3304. # check_outer_fusion_loop_level_attr to the fusion phase, avoiding grouping
3305. # kernels at fusion time that look like they can be fused
3306. # but actually cannot be.
  3307. if node.check_outer_fusion_loop_level_attr(
  3308. cpp_kernel_proxy_list, node.outer_loop_fusion_depth
  3309. ):
  3310. # Merge the cpp_kernel_proxy_list into cpp_kernel_proxy
  3311. outer_fusion_cpp_kernel_proxy = node.merge_outer_fusion_kernels(
  3312. cpp_kernel_proxy_list,
  3313. )
  3314. kernel_group.finalize_kernel(
  3315. outer_fusion_cpp_kernel_proxy,
  3316. [_node for _nodes in nodes_list for _node in _nodes],
  3317. )
  3318. else:
  3319. # Fall back to standard loop codegen
  3320. for _kernel_proxy, _nodes in zip(cpp_kernel_proxy_list, nodes_list):
  3321. kernel_group.finalize_kernel(_kernel_proxy, _nodes)
  3322. else:
  3323. nodes: List[SchedulerNode] = node.get_nodes() # type: ignore[assignment]
  3324. cpp_kernel_proxy = CppKernelProxy(kernel_group)
  3325. cpp_kernel_proxy.codegen_nodes(nodes)
  3326. kernel_group.finalize_kernel(cpp_kernel_proxy, nodes)
  3327. args_num = self._get_scheduled_num_args()
  3328. if args_num > CppScheduling.MAX_FUSED_KERNEL_ARGS_NUM:
  3329. self._set_flush_status(True)
  3330. def is_cpp_template(self, node: BaseSchedulerNode) -> bool:
  3331. return isinstance(node, SchedulerNode) and isinstance(
  3332. node.node, ir.CppTemplateBuffer
  3333. )
  3334. def codegen_template(
  3335. self,
  3336. template_node: BaseSchedulerNode,
  3337. epilogue_nodes: Sequence[BaseSchedulerNode],
  3338. ):
  3339. """
  3340. Codegen a CPP template, possibly with fused epilogues
  3341. """
  3342. counters["inductor"]["cpp_epilogue_fusion_counter"] += len(epilogue_nodes)
  3343. assert self.is_cpp_template(
  3344. template_node
  3345. ), "Template node passed to CppScheduler.codegen_template must be a SchedulerNode that wraps a CppTemplateBuffer"
  3346. template_node = cast(SchedulerNode, template_node)
  3347. _, (_, rnumel) = template_node.group
  3348. assert rnumel == ()
  3349. ctb: ir.CppTemplateBuffer = cast(ir.CppTemplateBuffer, template_node.node)
  3350. epilogue_ir_nodes: List[Optional[ir.Buffer]] = [n.node for n in epilogue_nodes]
  3351. assert all(
  3352. isinstance(n, ir.ComputedBuffer) for n in epilogue_ir_nodes
  3353. ), "Epilogue nodes must all be instances of ir.ComputedBuffer"
  3354. kernel, render = ctb.make_kernel_render(ctb, epilogue_nodes=epilogue_ir_nodes)
  3355. with kernel:
  3356. for node in [template_node, *epilogue_nodes]:
  3357. node.decide_inplace_update()
  3358. node.mark_run() # type: ignore[attr-defined]
  3359. src_code = render()
  3360. with V.set_kernel_handler(kernel):
  3361. node_schedule = [template_node, *epilogue_nodes]
  3362. kernel_name = self.define_kernel(src_code, node_schedule, kernel.args)
  3363. kernel.call_kernel(kernel_name, ctb)
  3364. V.graph.removed_buffers |= kernel.removed_buffers
  3365. self.scheduler.free_buffers()
  3366. def _get_scheduled_num_args(self):
  3367. return self.kernel_group.get_num_args()
  3368. def ready_to_flush(self):
  3369. return self._ready_to_flush
  3370. def codegen_sync(self):
  3371. pass
  3372. def define_kernel(self, src_code, nodes, kernel_args=None):
  3373. wrapper = V.graph.wrapper_code
  3374. fused_name = (
  3375. get_fused_kernel_name(nodes, config.cpp.descriptive_names)
  3376. if config.cpp.descriptive_names
  3377. else ""
  3378. )
  3379. kernel_name = "_".join(["cpp", fused_name, wrapper.next_kernel_suffix()])
  3380. kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel"
  3381. src_code = src_code.replace(str(Placeholder.KERNEL_NAME), kernel_decl_name)
  3382. src_code = src_code.replace(str(Placeholder.DESCRIPTIVE_NAME), kernel_name)
  3383. # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does
  3384. # not use BracesBuffer, so we have no good indicator of a C++ buffer atm.
  3385. src_code = src_code.replace("#pragma CMT", "//")
  3386. compile_wrapper = IndentedBuffer()
  3387. args = self.kernel_group.args if kernel_args is None else kernel_args
  3388. _, _, arg_types = args.cpp_argdefs()
  3389. if not V.graph.cpp_wrapper:
  3390. compile_wrapper.writeline(f"async_compile.cpp_pybinding({arg_types!r}, '''")
  3391. compile_wrapper.splice(src_code, strip=True)
  3392. if not V.graph.cpp_wrapper:
  3393. compile_wrapper.writeline("''')")
  3394. wrapper.define_kernel(kernel_name, compile_wrapper.getvalue(), cuda=False)
  3395. return kernel_name
  3396. def flush(self):
  3397. src_code = self.kernel_group.codegen_group()
  3398. if src_code:
  3399. kernel_name = self.define_kernel(
  3400. src_code, self.kernel_group.scheduled_nodes
  3401. )
  3402. self.kernel_group.call_kernel(V.graph.wrapper_code, kernel_name)
  3403. self.reset_kernel_group()
  3404. self._set_flush_status(False)
  3405. class KernelGroup:
  3406. def __init__(self):
  3407. super().__init__()
  3408. self.args = KernelArgs()
  3409. self.loops_code = BracesBuffer()
  3410. self.ws = WorkSharing(self.loops_code)
  3411. self.stack = contextlib.ExitStack()
  3412. self.stack.enter_context(self.ws)
  3413. self.scheduled_nodes = []
  3414. def new_kernel(self, cls, *args):
  3415. return cls(self.args, parallel_num_threads(), *args)
  3416. def finalize_kernel(self, new_kernel, nodes):
  3417. self.scheduled_nodes += nodes
  3418. code = self.loops_code
  3419. ws = self.ws
  3420. new_kernel.codegen_loops(code, ws)
  3421. def get_num_args(self):
  3422. arg_defs, call_args, arg_types = self.args.cpp_argdefs()
  3423. args_num = len(arg_defs)
  3424. return args_num
  3425. def codegen_group(self, name=None) -> str:
  3426. self.stack.close()
  3427. if not self.scheduled_nodes:
  3428. return ""
  3429. code = BracesBuffer()
  3430. # 1. Include header files
  3431. # TODO: support kernel profile on other platforms
  3432. enable_kernel_profile = (
  3433. config.cpp.enable_kernel_profile and sys.platform == "linux"
  3434. )
  3435. if enable_kernel_profile:
  3436. code.writelines(["#include <ATen/record_function.h>"])
  3437. code.writeline(codecache.cpp_prefix())
  3438. # 2. Function definition
  3439. kernel_decl_name = str(Placeholder.KERNEL_NAME) if name is None else name
  3440. kernel_name = str(Placeholder.DESCRIPTIVE_NAME) if name is None else name
  3441. arg_defs, _, _ = self.args.cpp_argdefs()
  3442. arg_defs = ",\n".ljust(25).join(arg_defs)
  3443. code.writeline(f'extern "C" void {kernel_decl_name}({arg_defs})')
  3444. # 3. Function body
  3445. with code.indent():
  3446. if enable_kernel_profile:
  3447. graph_id = V.graph.graph_id
  3448. prefix = "graph_" + str(graph_id) + "_" if graph_id is not None else ""
  3449. code.writelines(
  3450. [
  3451. f'RECORD_FUNCTION("{prefix + kernel_name}", c10::ArrayRef<c10::IValue>({{}}));'
  3452. ]
  3453. )
  3454. for old, new in self.args.aliases():
  3455. code.writeline(f"auto {old} = {new};")
  3456. code.splice(self.loops_code)
  3457. return code.getvalue()
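# Illustrative sketch (added note): the emitted group roughly looks like
#
#     #include <ATen/record_function.h>          // only with kernel profiling
#     ... contents of codecache.cpp_prefix() ...
#     extern "C" void kernel(const float* in_ptr0, float* out_ptr0)
#     {
#         // optional RECORD_FUNCTION(...) when profiling is enabled
#         // alias definitions from self.args.aliases()
#         ... loops_code (the per-kernel loop nests) ...
#     }
#
# The exact argument list comes from self.args.cpp_argdefs(); the names above
# are placeholders for the sketch.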
  3458. def call_kernel(self, wrapper, kernel_name):
  3459. _, call_args, arg_types = self.args.cpp_argdefs()
  3460. wrapper.generate_kernel_call(
  3461. kernel_name, call_args, cuda=False, arg_types=arg_types
  3462. )
  3463. class CppWrapperKernelGroup(KernelGroup):
  3464. def __init__(self):
  3465. super().__init__()
  3466. self.args = CppWrapperKernelArgs()
  3467. class WorkSharing:
  3468. def __init__(self, code):
  3469. self.code = code
  3470. self.in_parallel = False
  3471. self.num_threads = None
  3472. self.stack = contextlib.ExitStack()
  3473. def parallel(self, threads):
  3474. if self.in_parallel and threads != self.num_threads:
  3475. # wrong number of threads
  3476. self.close()
  3477. if not self.in_parallel:
  3478. self.num_threads = threads
  3479. self.in_parallel = True
  3480. if config.cpp.dynamic_threads:
  3481. self.code.writeline("#pragma omp parallel")
  3482. else:
  3483. self.code.writeline(f"#pragma omp parallel num_threads({threads})")
  3484. self.stack.enter_context(self.code.indent())
  3485. self.code.writeline(
  3486. "int tid = omp_get_thread_num();",
  3487. )
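# Illustrative sketch (added note): parallel(8) emits either
#
#     #pragma omp parallel
#     {
#         int tid = omp_get_thread_num();
#         ...
#     }
#
# or the same block with num_threads(8) on the pragma, depending on
# config.cpp.dynamic_threads. The braces come from the surrounding
# BracesBuffer indentation (an assumption of this sketch).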
  3488. def single(self):
  3489. if self.in_parallel:
  3490. self.code.writeline("#pragma omp single")
  3491. return self.in_parallel
  3492. def close(self):
  3493. self.stack.close()
  3494. self.in_parallel = False
  3495. def __enter__(self):
  3496. self.stack.__enter__()
  3497. return self
  3498. def __exit__(self, exc_type, exc_val, exc_tb):
  3499. self.stack.__exit__(exc_type, exc_val, exc_tb)
  3500. @dataclasses.dataclass
  3501. class LoopLevel:
  3502. var: Optional[sympy.Expr] = None
  3503. size: Optional[sympy.Expr] = None
  3504. offset: sympy.Expr = sympy.Integer(0)
  3505. steps: sympy.Expr = sympy.Integer(1)
  3506. parallel: int = 0
  3507. simd_omp: bool = False
  3508. simd_vec: bool = False
  3509. collapsed: bool = False
  3510. is_reduction: bool = False
  3511. parent: Optional["LoopLevel"] = None
  3512. # the next inner level of the loop, empty if it is inner-most
  3513. # contains >1 LoopLevel if the inner level of loop is split
  3514. inner: List["LoopLevel"] = dataclasses.field(default_factory=list)
  3515. # kernel assigned to this loop level, only valid when it is a leaf
  3516. kernel: Optional[CppKernel] = None
  3517. def __post_init__(self):
3518. # For the C++/OpenMP backend, `codecache.pick_vec_isa()`, which checks the
3519. # vectorization ISA, is a time-consuming, one-shot operation. Initializing
3520. # `simd_nelements` as a dataclass field default would invoke
3521. # `codecache.pick_vec_isa()` at import time of the `codegen.cpp` package,
3522. # since `LoopLevel` is decorated with `@dataclasses.dataclass`. That slows the
3523. # import down and could add compilation overhead to the Triton backend.
3524. # Therefore, we moved the `simd_nelements` initialization to
3525. # `__post_init__`.
  3526. picked_vec_isa: codecache.VecISA = codecache.pick_vec_isa()
  3527. self.simd_nelements: int = picked_vec_isa.nelements() if picked_vec_isa else 0
  3528. def get_kernels(self) -> List[CppKernel]:
  3529. """Get all kernel objects under this loop level"""
  3530. if self.kernel:
  3531. return [self.kernel]
  3532. kernels = []
  3533. for loop in self.inner:
  3534. kernels += loop.get_kernels()
  3535. return kernels
  3536. def get_root(self):
3537. """Get the outermost root of this loop level"""
  3538. root = self
  3539. while root.parent:
  3540. root = root.parent
  3541. return root
  3542. def set_kernel(self, kernel: CppKernel):
  3543. """
  3544. Set the kernel under this loop level. No split is allowed under
  3545. this loop level.
  3546. """
  3547. if not self.inner:
  3548. self.kernel = kernel
  3549. loop: Optional[LoopLevel] = self
  3550. assert loop is not None
  3551. return
  3552. assert len(self.inner) == 1
  3553. self.inner[0].set_kernel(kernel)
  3554. def get_loops_at(self, depth) -> List["LoopLevel"]:
  3555. if depth == 0:
  3556. return [self]
  3557. else:
  3558. loops = []
  3559. for loop in self.inner:
  3560. loops += loop.get_loops_at(depth - 1)
  3561. return loops
  3562. def split_with_tiling(self, depth, factor):
  3563. def clone_inner():
  3564. inner = []
  3565. if self.inner:
  3566. for loop in self.inner:
  3567. inner.append(loop.clone())
  3568. return inner
  3569. def do_split_with_tiling():
  3570. sympy_factor = sympy.Integer(factor)
  3571. offset = FloorDiv(self.size, sympy_factor) * sympy_factor
  3572. main_loop = LoopLevel(self.var, offset)
  3573. main_loop.steps = sympy_factor
  3574. main_loop.parallel = self.parallel
  3575. main_loop.collapsed = False
  3576. main_loop.is_reduction = self.is_reduction
  3577. main_loop.inner = clone_inner()
  3578. if main_loop.inner:
  3579. for loop in main_loop.inner:
  3580. loop.parent = main_loop
  3581. tail_loop = LoopLevel(self.var, self.size)
  3582. tail_loop.offset = offset
  3583. tail_loop.parallel = self.parallel
  3584. tail_loop.collapsed = False
  3585. tail_loop.is_reduction = self.is_reduction
  3586. tail_loop.inner = clone_inner()
  3587. if tail_loop.inner:
  3588. for loop in tail_loop.inner:
  3589. loop.parent = tail_loop
  3590. return main_loop, tail_loop
  3591. if depth == 0:
  3592. main_loop, tail_loop = do_split_with_tiling()
  3593. parent = self.parent
  3594. if parent:
  3595. parent.inner = [main_loop, tail_loop]
  3596. main_loop.parent = parent
  3597. tail_loop.parent = parent
  3598. return main_loop, tail_loop
  3599. else:
  3600. assert len(self.inner) == 1
  3601. return self.inner[0].split_with_tiling(depth - 1, factor)
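# Illustrative sketch (added note): splitting a loop over [0, 10) with factor 4
# produces a main loop covering [0, 8) with step 4 (offset = 10 // 4 * 4 = 8)
# and a tail loop covering [8, 10) with step 1, each inheriting the original
# loop's parallel/reduction attributes and cloned inner loops.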
  3602. def clone(self):
  3603. loop = copy(self)
  3604. loop.inner = []
  3605. if self.inner:
  3606. for inner_loop in self.inner:
  3607. inner_loop_clone = inner_loop.clone()
  3608. inner_loop_clone.parent = loop
  3609. loop.inner.append(inner_loop_clone)
  3610. loop.kernel = deepcopy(self.kernel)
  3611. return loop
  3612. def lines(self):
  3613. offset_expr = cexpr_index(self.offset)
  3614. size_expr = cexpr_index(self.size)
  3615. if config.cpp.no_redundant_loops and offset_expr == size_expr:
  3616. return None
  3617. simd = (
  3618. f"simd simdlen({self.simd_nelements}) "
  3619. if self.simd_omp and self.simd_nelements > 1
  3620. else ""
  3621. )
  3622. if self.parallel:
  3623. # TODO(jansel): look into chunk size and other schedules
  3624. line1 = "#pragma omp for"
  3625. if self.parallel > 1:
  3626. line1 += f" collapse({self.parallel})"
  3627. if self.simd_omp:
  3628. line1 = line1.replace(" for ", f" for {simd}")
  3629. elif self.simd_vec:
  3630. line1 = ""
  3631. elif self.simd_omp:
  3632. line1 = f"#pragma omp {simd}"
  3633. elif not self.is_reduction and codecache.is_gcc():
  3634. line1 = "#pragma GCC ivdep"
  3635. else:
  3636. line1 = ""
  3637. offset_str = f"{INDEX_TYPE} {self.var}={offset_expr}"
  3638. size_str = f"{self.var}<{size_expr}"
  3639. steps_str = f"{self.var}+={cexpr_index(self.steps)}"
  3640. line2 = f"for({offset_str}; {size_str}; {steps_str})"
  3641. if self.collapsed or not line1:
  3642. return [line2]
  3643. return [line1, line2]
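# Illustrative sketch (added note): a non-parallel loop level with var=x0,
# size=1024, steps=16, simd_omp=True and simd_nelements=8 renders roughly as
#
#     #pragma omp simd simdlen(8)
#     for(long x0=0; x0<1024; x0+=16)
#
# assuming INDEX_TYPE renders as `long`. A simd_vec level emits no pragma at
# all, and a collapsed level only emits the `for(...)` line.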
  3644. @dataclasses.dataclass
  3645. class LoopNestWithSplit:
  3646. """
3647. A loop-nest-like structure, but with some loop levels split along
  3648. the loop range into the main tiling loop and the tail. It is built
  3649. with the `build` method as a loop nest and then split with
  3650. `split_with_tiling` at some depth.
  3651. A typical case is for vectorization where we typically split at the inner-most
  3652. loop level. A more complicated case is 2D tiling where we split at
  3653. both inner-most and outer levels.
  3654. """
  3655. root: Optional[List[LoopLevel]] = None
  3656. kernel: Optional[CppKernel] = None
  3657. @staticmethod
  3658. def build(kernel: CppKernel):
  3659. """Build a LoopNest with the given `kernel` as the leaf"""
  3660. itervars = kernel.itervars
  3661. ranges = kernel.ranges
  3662. reduction_depth = kernel.reduction_depth
  3663. assert reduction_depth is not None
  3664. root: List[LoopLevel] = []
  3665. levels: List[LoopLevel] = root
  3666. loop: Optional[LoopLevel] = None
  3667. for loop_idx, (var, size) in enumerate(zip(itervars, ranges)):
  3668. loop = LoopLevel(var, size, parent=loop)
  3669. if loop_idx >= reduction_depth:
  3670. loop.is_reduction = kernel.is_reduction
  3671. levels.append(loop)
  3672. levels = loop.inner
  3673. loop_nest = LoopNestWithSplit(root)
  3674. if loop:
  3675. loop.kernel = kernel
  3676. else:
  3677. loop_nest.kernel = kernel
  3678. return loop_nest
  3679. def __bool__(self):
  3680. return bool(self.root)
  3681. def get_loops_at(self, depth) -> List[LoopLevel]:
  3682. """Get all the loop levels at the given `depth` (most outer loop has depth 0)"""
  3683. loops: List[LoopLevel] = []
  3684. assert self.root is not None
  3685. for loop in self.root:
  3686. loops += loop.get_loops_at(depth)
  3687. return loops
  3688. @cache_on_self
  3689. def max_parallel_depth(self):
  3690. """
  3691. Maximal allowed depth for parallelism:
3692. 1) Levels without splitting, and
3693. 2) Levels that are either all reduction or all non-reduction.
  3694. When the loop is split at the top level, the max depth is 1.
  3695. """
  3696. max_depth = 0
  3697. assert self.root is not None
  3698. loops = self.root
  3699. if len(loops) > 1:
  3700. return 1
  3701. is_reduction = loops[0].is_reduction if loops else False
  3702. while len(loops) == 1 and loops[0].is_reduction == is_reduction:
  3703. max_depth += 1
  3704. loops = loops[0].inner
  3705. return max_depth
  3706. def is_reduction_only(self):
  3707. """
  3708. Whether all the loops are for reduction. Reduction loops
  3709. are always the inner most ones.
  3710. """
  3711. return (
  3712. self.root is not None and len(self.root) > 0 and self.root[0].is_reduction
  3713. )
  3714. def mark_parallel(self, par_depth):
  3715. assert (
  3716. par_depth <= self.max_parallel_depth()
  3717. ), "Parallel depth cannot exceed the maximal allowed parallel depth"
  3718. assert self.root is not None
  3719. loops = self.root
  3720. for loop in loops:
  3721. loop.parallel = par_depth
  3722. for i in range(1, par_depth):
  3723. loops = loops[0].inner
  3724. loops[0].collapsed = True
  3725. def split_with_tiling(self, depth, factor):
  3726. """
3727. Split the loop into main and tail loops at the given `depth` so that
3728. the main loop covers the range `floor_div(range, factor) * factor` and
  3729. the tail loop handles the remainder. The main loop is tiled
  3730. according to the `factor`.
  3731. """
  3732. loops = self.get_loops_at(depth)
  3733. assert len(loops) == 1
  3734. split_loops = loops[0].split_with_tiling(0, factor)
  3735. if depth == 0:
  3736. self.root = split_loops
  3737. return split_loops
  3738. def get_kernels(self) -> List[CppKernel]:
  3739. """Get all kernel objects under this loop nest"""
  3740. if self.kernel:
  3741. return [self.kernel]
  3742. kernels: List[CppKernel] = []
  3743. assert self.root is not None
  3744. for loop in self.root:
  3745. kernels += loop.get_kernels()
  3746. return kernels