// NOTE(review): removed page-capture residue here (a "xnnpack.h 218 KB" size banner and a
// concatenated line-number gutter from the scraping tool) — it was not part of the header.
  1. // Copyright (c) Facebook, Inc. and its affiliates.
  2. // All rights reserved.
  3. //
  4. // Copyright 2019 Google LLC
  5. //
  6. // This source code is licensed under the BSD-style license found in the
  7. // LICENSE file in the root directory of this source tree.
  8. #pragma once
  9. #include <stdbool.h>
  10. #include <stddef.h>
  11. #include <stdint.h>
  12. #include <pthreadpool.h>
  13. #ifdef __cplusplus
  14. extern "C" {
  15. #endif
  16. /// The number of bytes XNNPACK may read beyond array bounds.
  17. /// The caller must allocate at least this many extra bytes after the tensor data passed to XNNPACK.
  18. ///
  19. /// Note: XNNPACK reads, but never writes beyond array bounds.
  20. #define XNN_EXTRA_BYTES 16
  21. /// Maximum number of dimensions in tensor shape.
  22. #define XNN_MAX_TENSOR_DIMS 6
  23. /// Allow sparse inference in a Runtime.
  24. ///
  25. /// Note: this flag hints XNNPACK to consider sparse inference, but does not guarantee it.
  26. #define XNN_FLAG_HINT_SPARSE_INFERENCE 0x00000001
  27. /// Allow IEEE FP16 inference in a Runtime.
  28. ///
  29. /// Note: this flag hints XNNPACK to consider IEEE FP16 inference, but does not guarantee it.
  30. #define XNN_FLAG_HINT_FP16_INFERENCE 0x00000002
  31. /// Force IEEE FP16 inference in a Runtime, and fail if FP16 inference is not possible.
  32. ///
  33. /// Note: this flag guarantees that XNNPACK will use IEEE FP16 inference, or fail to create the Runtime object.
  34. /// Warning: on x86 systems FP16 computations will be emulated at a substantial performance cost.
  35. #define XNN_FLAG_FORCE_FP16_INFERENCE 0x00000004
  36. /// Enable timing of each operator's runtime.
  37. #define XNN_FLAG_BASIC_PROFILING 0x00000008
  38. /// Enable the just-in-time compiler.
  39. #define XNN_FLAG_JIT 0x00000010
  40. /// The convolution operator represents a depthwise convolution, and uses HWGo layout for filters.
  41. #define XNN_FLAG_DEPTHWISE_CONVOLUTION 0x00000001
  42. /// Assume transposed weights in a fully connected operator.
  43. #define XNN_FLAG_TRANSPOSE_WEIGHTS 0x00000001
  44. /// The operator assumes NHWC layout for the input, regardless of the output layout.
  45. #define XNN_FLAG_INPUT_NHWC 0x00000002
  46. /// Match "SAME" padding in TensorFlow. Exact padding values are computed dynamically depending on input size.
  47. #define XNN_FLAG_TENSORFLOW_SAME_PADDING 0x00000004
  48. /// Assume transposed weights in a batch matrix multiply operator.
  49. #define XNN_FLAG_TRANSPOSE_B XNN_FLAG_TRANSPOSE_WEIGHTS
  50. /// Assume transposed input in a batch matrix multiply operator.
  51. #define XNN_FLAG_TRANSPOSE_A 0x00000002
  52. /// Implicitly flatten and reshape input of a Fully Connected operator into a 2D tensor.
  53. #define XNN_FLAG_TENSORFLOW_RESHAPE_2D 0x00000004
  54. /// Match behaviour of TensorFlow 1.x.
  55. #define XNN_FLAG_TENSORFLOW_LEGACY_MODE 0x00000004
  56. /// Static weights of the FP16 operator are in FP32 format.
  57. #define XNN_FLAG_FP32_STATIC_WEIGHTS 0x00000008
  58. /// Align corners of input and output images in resize operations.
  59. #define XNN_FLAG_ALIGN_CORNERS 0x00000008
  60. /// Yield worker threads of the thread pool to the system scheduler after the inference.
  61. #define XNN_FLAG_YIELD_WORKERS 0x00000010
  62. /// Use transient indirection buffer to reduce memory footprint
  63. #define XNN_FLAG_TRANSIENT_INDIRECTION_BUFFER 0x00000020
  64. /// Reduce the dimensions.
  65. #define XNN_FLAG_REDUCE_DIMS 0x00000040
  66. /// The number of entries in an array of xnn_dynamic_quantization_params that XNNPACK may read beyond array bounds.
  67. /// The caller must allocate at least this many extra xnn_dynamic_quantization_params before passing the array to XNNPACK.
  68. ///
  69. /// Note: XNNPACK reads, but never writes beyond array bounds.
  70. #define XNN_EXTRA_QUANTIZATION_PARAMS 8
  71. struct xnn_dynamic_quantization_params {
  72. int32_t zero_point;
  73. float scale;
  74. };
  75. /// Status code for any XNNPACK function call.
  76. enum xnn_status {
  77. /// The call succeeded, and all output arguments now contain valid data.
  78. xnn_status_success = 0,
  79. xnn_status_uninitialized = 1,
  80. xnn_status_invalid_parameter = 2,
  81. xnn_status_invalid_state = 3,
  82. xnn_status_unsupported_parameter = 4,
  83. xnn_status_unsupported_hardware = 5,
  84. xnn_status_out_of_memory = 6,
  85. xnn_status_reallocation_required = 7,
  86. };
  87. struct xnn_allocator {
  88. /// User-specified pointer that will be passed as-is to all functions in this structure.
  89. void* context;
  90. /// Pointer to a function to be called for general memory allocation.
  91. ///
  92. /// @param context - The user-specified pointer from xnn_allocator structure.
  93. /// @param size - The size of the memory block to allocate, in bytes.
  94. ///
  95. /// @returns Pointer to the allocated memory block of at least @ref size bytes.
  96. /// If allocation fails, the function must return NULL.
  97. void* (*allocate)(void* context, size_t size);
  98. /// Pointer to a function to be called for general memory re-allocation, i.e. to increase or shrink a previously
  99. /// allocated memory block. The content of the old memory block is copied to the new memory block.
  100. ///
  101. /// @param context - The user-specified pointer from xnn_allocator structure.
  102. /// @param pointer - Pointer to a memory block allocated by @ref allocate or @ref reallocate functions. Can be NULL.
  103. /// If the pointer is NULL, the @ref reallocate call is equivalent to an @ref allocate call.
  104. /// @param size - The new size of the memory block to allocate, in bytes.
  105. ///
  106. /// @returns Pointer to the newly allocated memory block of at least @ref size bytes with the content of the previous
  107. /// memory block.
  108. /// If allocation fails, the function must return NULL, but must not release the previous memory block.
  109. void* (*reallocate)(void* context, void* pointer, size_t size);
  110. /// Pointer to a function to be called for general memory de-allocation.
  111. ///
  112. /// @param context - The user-specified pointer from xnn_allocator structure.
  113. /// @param pointer - Pointer to a memory block allocated by @ref allocate or @ref reallocate functions. Can be NULL.
  114. /// If the pointer is NULL, the @ref deallocate call is a no-op.
  115. void (*deallocate)(void* context, void* pointer);
  116. /// Pointer to a function to be called for aligned memory allocation.
  117. ///
  118. /// @param context - The user-specified pointer from xnn_allocator structure.
  119. /// @param alignment - The alignment of the memory block to allocate, in bytes. Alignment is always a power-of-2.
  120. /// @param size - The size of the memory block to allocate, in bytes.
  121. ///
  122. /// @returns Pointer to the allocated memory block of at least @ref size bytes.
  123. /// If allocation fails, the function must return NULL.
  124. void* (*aligned_allocate)(void* context, size_t alignment, size_t size);
  125. /// Pointer to a function to be called for aligned memory de-allocation.
  126. ///
  127. /// @param context - The user-specified pointer from xnn_allocator structure.
  128. /// @param pointer - Pointer to a memory block allocated by @ref aligned_allocate function. Can be NULL.
  129. /// If the pointer is NULL, the @ref aligned_deallocate call is a no-op.
  130. void (*aligned_deallocate)(void* context, void* pointer);
  131. };
  132. /// Initialize XNNPACK library.
  133. ///
  134. /// XNNPACK must be successfully initialized before use. During initialization, XNNPACK populates internal structures
  135. /// depending on the host processor. Initialization can be time-consuming.
  136. ///
  137. /// @param[in] allocator - structure with function pointers to be used for memory allocation and de-allocation.
  138. /// If this argument is NULL, system-provided memory management functions (e.g. malloc/free)
  139. /// will be used.
  140. ///
  141. /// @retval xnn_status_success - XNNPACK is successfully initialized and ready to use.
  142. /// @retval xnn_status_out_of_memory - initialization failed due to out-of-memory condition.
  143. /// @retval xnn_status_unsupported_hardware - initialization failed because the host processor does not satisfy the
  144. /// minimum hardware requirements for XNNPACK. E.g. this may happen on x86
  145. /// processors without SSE2 extension, or on 32-bit ARM processors without
  146. /// the NEON SIMD extension.
  147. enum xnn_status xnn_initialize(const struct xnn_allocator* allocator);
  148. /// Deinitialize XNNPACK library.
  149. ///
  150. /// To avoid memory and resource leaks, users must call xnn_deinitialize once for each successful xnn_initialize call.
  151. ///
  152. /// @retval xnn_status_success - deinitialization call succeeded.
  153. enum xnn_status xnn_deinitialize(void);
  154. /// Subgraph is an abstract representation of a neural network model.
  155. /// Subgraph objects are used to define Values (tensors) and Nodes (operators) comprising the model.
  156. typedef struct xnn_subgraph* xnn_subgraph_t;
  157. /// Create an empty Subgraph object.
  158. ///
  159. /// @param external_value_ids - number of Value IDs to reserve for communication with external graph representation.
  160. /// The Subgraph object will avoid creating internal Value IDs in the
  161. /// [0, external_value_ids-1] range.
  162. /// @param flags - binary features of the subgraph. No supported flags are currently defined.
  163. /// @param subgraph_out - pointer to the variable that will be initialized with a handle to the Subgraph object upon
  164. /// successful return.
  165. enum xnn_status xnn_create_subgraph(
  166. uint32_t external_value_ids,
  167. uint32_t flags,
  168. xnn_subgraph_t* subgraph_out);
  169. /// Destroy a Subgraph object, as well as Values, and Nodes associated with the subgraph.
  170. ///
  171. /// @param subgraph - the Subgraph object to destroy.
  172. enum xnn_status xnn_delete_subgraph(
  173. xnn_subgraph_t subgraph);
  174. #define XNN_VALUE_FLAG_EXTERNAL_INPUT 0x00000001
  175. #define XNN_VALUE_FLAG_EXTERNAL_OUTPUT 0x00000002
  176. #define XNN_VALUE_FLAG_PERSISTENT 0x00000004
  177. #define XNN_INVALID_VALUE_ID UINT32_MAX
  178. /// Type of elements in a Value object.
  179. enum xnn_datatype {
  180. /// Invalid data type. Valid Values never have this datatype.
  181. xnn_datatype_invalid = 0,
  182. /// IEEE754 single-precision floating-point.
  183. xnn_datatype_fp32 = 1,
  184. /// IEEE754 half-precision floating-point.
  185. xnn_datatype_fp16 = 2,
  186. /// Quantized 8-bit signed integer with shared per-Value quantization parameters.
  187. xnn_datatype_qint8 = 3,
  188. /// Quantized 8-bit unsigned integer with shared per-Value quantization parameters.
  189. xnn_datatype_quint8 = 4,
  190. /// Quantized 32-bit signed integer with shared per-Value quantization parameters.
  191. xnn_datatype_qint32 = 5,
  192. /// Quantized 8-bit signed integer with shared per-channel quantization parameters.
  193. xnn_datatype_qcint8 = 6,
  194. /// Quantized 32-bit signed integer with shared per-channel quantization parameters.
  195. xnn_datatype_qcint32 = 7,
  196. /// Quantized 4-bit signed integer with shared per-channel quantization parameters.
  197. xnn_datatype_qcint4 = 8,
  198. /// Dynamically quantized 8-bit signed integer with per-batch quantization parameters.
  199. xnn_datatype_qdint8 = 9,
  200. };
  201. /// Define a tensor-type Value and add it to a Subgraph.
  202. ///
  203. /// @param subgraph - a Subgraph object that will own the created Value.
  204. /// @param datatype - type of the tensor elements.
  205. /// @param num_dims - number of dimensions in the shape.
  206. /// @param dims - pointer to an array of @a num_dims shape dimensions. If num_dims is 0, this pointer can be NULL.
  207. /// XNNPACK does not keep any pointers to this array after the function returns.
  208. /// @param data - pointer to static data used for tensor initialization. If the tensor is not statically initialized,
  209. /// this pointer must be NULL. If non-NULL, the life-time of the static data must exceed the life-time
  210. /// of the Subgraph object, and of any Runtime objects created from the Subgraph.
  211. /// @param external_id - external ID for the Value. The ID must be within the range of reserved Value IDs specified on
  212. /// the Subgraph creation. If the external ID is XNN_INVALID_VALUE_ID, an internal ID will be
  213. /// created for the Value.
  214. /// @param flags - binary features of the Value. Supported values are any combination of XNN_VALUE_FLAG_EXTERNAL_INPUT
  215. /// and XNN_VALUE_FLAG_EXTERNAL_OUTPUT.
  216. /// @param id_out - pointer to the variable that will be initialized with the Value ID upon successful return. If a
  217. /// valid @a external_id was provided, the variable will be initialized with the @a external_id value.
  218. enum xnn_status xnn_define_tensor_value(
  219. xnn_subgraph_t subgraph,
  220. enum xnn_datatype datatype,
  221. size_t num_dims,
  222. const size_t* dims,
  223. const void* data,
  224. uint32_t external_id,
  225. uint32_t flags,
  226. uint32_t* id_out);
  227. /// Define a quantized tensor-type Value and add it to a Subgraph.
  228. ///
  229. /// @param subgraph - a Subgraph object that will own the created Value.
  230. /// @param datatype - type of the tensor elements.
  231. /// @param zero_point - offset from zero to subtract from the quantized elements in the Value.
  232. /// @param scale - multiplication factor to convert quantized elements to real representation.
  233. /// @param num_dims - number of dimensions in the shape.
  234. /// @param dims - pointer to an array of @a num_dims shape dimensions. If num_dims is 0, this pointer can be NULL.
  235. /// XNNPACK does not keep any pointers to this array after the function returns.
  236. /// @param data - pointer to static data used for tensor initialization. If the tensor is not statically initialized,
  237. /// this pointer must be NULL. If non-NULL, the life-time of the static data must exceed the life-time
  238. /// of the Subgraph object, and of any Runtime objects created from the Subgraph.
  239. /// @param external_id - external ID for the Value. The ID must be within the range of reserved Value IDs specified on
  240. /// the Subgraph creation. If the external ID is XNN_INVALID_VALUE_ID, an internal ID will be
  241. /// created for the Value.
  242. /// @param flags - binary features of the Value. Supported values are any combination of XNN_VALUE_FLAG_EXTERNAL_INPUT
  243. /// and XNN_VALUE_FLAG_EXTERNAL_OUTPUT.
  244. /// @param id_out - pointer to the variable that will be initialized with the Value ID upon successful return. If a
  245. /// valid @a external_id was provided, the variable will be initialized with the @a external_id value.
  246. enum xnn_status xnn_define_quantized_tensor_value(
  247. xnn_subgraph_t subgraph,
  248. enum xnn_datatype datatype,
  249. int32_t zero_point,
  250. float scale,
  251. size_t num_dims,
  252. const size_t* dims,
  253. const void* data,
  254. uint32_t external_id,
  255. uint32_t flags,
  256. uint32_t* id_out);
/// Define a channelwise quantized tensor-type Value and add it to a Subgraph.
///
/// NOTE(review): unlike @ref xnn_define_channelwise_quantized_tensor_value_v2, this variant takes no zero_point
/// argument; presumably a zero point of 0 is implied — confirm against the implementation.
///
/// @param subgraph - a Subgraph object that will own the created Value.
/// @param datatype - type of the tensor elements.
/// @param scale - per-channel multiplication factors to convert quantized elements to real representation.
/// @param num_dims - number of dimensions in the shape.
/// @param channel_dim - index of the channel dimension in the tensor with per-channel quantization parameters.
/// @param dims - pointer to an array of @a num_dims shape dimensions. If num_dims is 0, this pointer can be NULL.
///               XNNPACK does not keep any pointers to this array after the function returns.
/// @param data - pointer to static data used for tensor initialization. If the tensor is not statically initialized,
///               this pointer must be NULL. If non-NULL, the life-time of the static data must exceed the life-time
///               of the Subgraph object, and of any Runtime objects created from the Subgraph.
/// @param external_id - external ID for the Value. The ID must be within the range of reserved Value IDs specified on
///                      the Subgraph creation. If the external ID is XNN_INVALID_VALUE_ID, an internal ID will be
///                      created for the Value.
/// @param flags - binary features of the Value. Supported values are any combination of XNN_VALUE_FLAG_EXTERNAL_INPUT
///                and XNN_VALUE_FLAG_EXTERNAL_OUTPUT.
/// @param id_out - pointer to the variable that will be initialized with the Value ID upon successful return. If a
///                 valid @a external_id was provided, the variable will be initialized with the @a external_id value.
  257. enum xnn_status xnn_define_channelwise_quantized_tensor_value(
  258. xnn_subgraph_t subgraph,
  259. enum xnn_datatype datatype,
  260. const float* scale,
  261. size_t num_dims,
  262. size_t channel_dim,
  263. const size_t* dims,
  264. const void* data,
  265. uint32_t external_id,
  266. uint32_t flags,
  267. uint32_t* id_out);
  268. /// Validate the dimensions, zero point, datatype, and scale of a quantized tensor-type.
  269. ///
  270. /// @param datatype - type of the tensor elements.
  271. /// @param zero_point - offset from zero to subtract from the quantized elements in the Value.
  272. /// @param scale - multiplication factor to convert quantized elements to real representation.
  273. /// @param num_dims - number of dimensions in the shape.
  274. /// @param dims - pointer to an array of @a num_dims shape dimensions. If num_dims is 0, this pointer can be NULL.
  275. /// XNNPACK does not keep any pointers to this array after the function returns.
  276. enum xnn_status xnn_validate_quantized_tensor(
  277. enum xnn_datatype datatype,
  278. int32_t zero_point,
  279. float scale,
  280. size_t num_dims,
  281. const size_t* dims);
  282. /// Validate the dimensions, channel_dim, zero point, datatype, and scales of a channelwise quantized tensor-type.
  283. ///
  284. /// @param datatype - type of the tensor elements.
  285. /// @param zero_point - offset from zero to subtract from the quantized elements in the Value.
  286. /// @param scale - per-channel multiplication factors to convert quantized elements to real representation.
  287. /// @param num_dims - number of dimensions in the shape.
  288. /// @param channel_dim - index of the channel dimension in the tensor with per-channel quantization parameters.
  289. /// Typically this is the first dimension (dimension #0) of the filter tensors in the Convolution,
  290. /// Deconvolution, and Fully Connected operators and the last dimension of the filter tensors in
  291. /// the Depthwise Convolution operators.
  292. /// @param dims - pointer to an array of @a num_dims shape dimensions. If num_dims is 0, this pointer can be NULL.
  293. /// XNNPACK does not keep any pointers to this array after the function returns.
  294. enum xnn_status xnn_validate_channelwise_quantized_tensor(
  295. enum xnn_datatype datatype,
  296. int32_t zero_point,
  297. const float* scale,
  298. size_t num_dims,
  299. size_t channel_dim,
  300. const size_t* dims);
  301. /// Define a channelwise quantized tensor-type Value and add it to a Subgraph.
  302. ///
  303. /// @param subgraph - a Subgraph object that will own the created Value.
  304. /// @param datatype - type of the tensor elements.
  305. /// @param zero_point - offset from zero to subtract from the quantized elements in the Value.
  306. /// @param scale - per-channel multiplication factors to convert quantized elements to real representation.
  307. /// @param num_dims - number of dimensions in the shape.
  308. /// @param channel_dim - index of the channel dimension in the tensor with per-channel quantization parameters.
  309. /// Typically this is the first dimension (dimension #0) of the filter tensors in the Convolution,
  310. /// Deconvolution, and Fully Connected operators and the last dimension of the filter tensors in
  311. /// the Depthwise Convolution operators.
  312. /// @param dims - pointer to an array of @a num_dims shape dimensions. If num_dims is 0, this pointer can be NULL.
  313. /// XNNPACK does not keep any pointers to this array after the function returns.
  314. /// @param data - pointer to static data used for tensor initialization. If the tensor is not statically initialized,
  315. /// this pointer must be NULL. If non-NULL, the life-time of the static data must exceed the life-time
  316. /// of the Subgraph object, and of any Runtime objects created from the Subgraph.
  317. /// @param external_id - external ID for the Value. The ID must be within the range of reserved Value IDs specified on
  318. /// the Subgraph creation. If the external ID is XNN_INVALID_VALUE_ID, an internal ID will be
  319. /// created for the Value.
  320. /// @param flags - binary features of the Value. Supported values are any combination of XNN_VALUE_FLAG_EXTERNAL_INPUT
  321. /// and XNN_VALUE_FLAG_EXTERNAL_OUTPUT.
  322. /// @param id_out - pointer to the variable that will be initialized with the Value ID upon successful return. If a
  323. /// valid @a external_id was provided, the variable will be initialized with the @a external_id value.
  324. enum xnn_status xnn_define_channelwise_quantized_tensor_value_v2(
  325. xnn_subgraph_t subgraph,
  326. enum xnn_datatype datatype,
  327. int32_t zero_point,
  328. const float* scale,
  329. size_t num_dims,
  330. size_t channel_dim,
  331. const size_t* dims,
  332. const void* data,
  333. uint32_t external_id,
  334. uint32_t flags,
  335. uint32_t* id_out);
  336. /// Define a dynamically quantized tensor-type Value and add it to a Subgraph.
  337. ///
  338. /// @param subgraph - a Subgraph object that will own the created Value.
  339. /// @param datatype - type of the tensor elements.
  340. /// @param num_dims - number of dimensions in the shape.
  341. /// @param num_nonbatch_dims - number of non-batch dimensions in the shape. The leading (num_dims - num_nonbatch_dims)
  342. /// dimensions will be flattened and treated as batch size. A set of quantization parameters
  343. /// will be calculated for each batch element.
  344. /// @param dims - pointer to an array of @a num_dims shape dimensions. If num_dims is 0, this pointer can be NULL.
  345. /// XNNPACK does not keep any pointers to this array after the function returns.
  346. /// @param external_id - external ID for the Value. The ID must be within the range of reserved Value IDs specified on
  347. /// the Subgraph creation. If the external ID is XNN_INVALID_VALUE_ID, an internal ID will be
  348. /// created for the Value.
  349. /// @param flags - binary features of the Value. No supported flags are currently defined.
  350. /// @param id_out - pointer to the variable that will be initialized with the Value ID upon successful return. If a
  351. /// valid @a external_id was provided, the variable will be initialized with the @a external_id value.
  352. enum xnn_status xnn_define_dynamically_quantized_tensor_value(
  353. xnn_subgraph_t subgraph,
  354. enum xnn_datatype datatype,
  355. size_t num_dims,
  356. size_t num_nonbatch_dims,
  357. const size_t* dims,
  358. uint32_t external_id,
  359. uint32_t flags,
  360. uint32_t* id_out);
  361. /// Define a Convert Node and add it to a Subgraph.
  362. ///
  363. /// @param subgraph - a Subgraph object that will own the created Node.
  364. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
  365. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
  366. /// shape must match the shape of the input tensor.
  367. /// @param flags - binary features of the Convert Node. No supported flags are currently defined.
  368. enum xnn_status xnn_define_convert(
  369. xnn_subgraph_t subgraph,
  370. uint32_t input_id,
  371. uint32_t output_id,
  372. uint32_t flags);
  373. /// Define a 2D Convolution Node and add it to a Subgraph.
  374. ///
  375. /// @param subgraph - a Subgraph object that will own the created Node.
  376. /// @param input_padding_top - implicit zero-padding above 2D input data. Must be 0 if XNN_FLAG_TENSORFLOW_SAME_PADDING
  377. /// flag is specified.
  378. /// @param input_padding_right - implicit zero-padding to the right of 2D input data. Must be 0 if
  379. /// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified.
  380. /// @param input_padding_bottom - implicit zero-padding below 2D input data. Must be 0 if
  381. /// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified.
  382. /// @param input_padding_left - implicit zero-padding to the left of 2D input data. Must be 0 if
  383. /// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified.
  384. /// @param kernel_height - kernel (filter) height.
  385. /// @param kernel_width - kernel (filter) width.
  386. /// @param subsampling_height - height of subsampling region for convolution output (convolution height stride).
  387. /// @param subsampling_width - width of subsampling region for convolution output (convolution width stride).
  388. /// @param dilation_height - dilation of kernel elements along the height dimension.
  389. /// @param dilation_width - dilation of kernel elements along the width dimension.
  390. /// @param groups - number of convolution groups.
  391. /// @param group_input_channels - number of input channels per group.
  392. /// @param group_output_channels - number of output channels per group.
  393. /// @param output_min - lower bound for clipping output values.
  394. /// @param output_max - upper bound for clipping output values.
  395. /// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph
  396. /// with [N, IH, IW, groups * group_input_channels] dimensions
  397. /// @param filter_id - Value ID for the filter tensor. The filter tensor must be a 4D tensor defined in the @a subgraph
  398. /// with [groups * group_output_channels, kernel_height, kernel_width, group_input_channels]
  399. /// dimensions.
  400. /// @param bias_id - Value ID for the bias tensor, or XNN_INVALID_VALUE_ID for a 2D Convolution Node without a bias. If
  401. /// present, the bias tensor must be a 1D tensor defined in the @a subgraph with [groups *
  402. /// group_output_channels] dimensions.
  403. /// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph
  404. /// with [N, OH, OW, groups * group_output_channels] dimensions.
  405. /// @param flags - binary features of the 2D Convolution Node. The only currently supported value is
  406. /// XNN_FLAG_TENSORFLOW_SAME_PADDING.
  407. enum xnn_status xnn_define_convolution_2d(
  408. xnn_subgraph_t subgraph,
  409. uint32_t input_padding_top,
  410. uint32_t input_padding_right,
  411. uint32_t input_padding_bottom,
  412. uint32_t input_padding_left,
  413. uint32_t kernel_height,
  414. uint32_t kernel_width,
  415. uint32_t subsampling_height,
  416. uint32_t subsampling_width,
  417. uint32_t dilation_height,
  418. uint32_t dilation_width,
  419. uint32_t groups,
  420. size_t group_input_channels,
  421. size_t group_output_channels,
  422. float output_min,
  423. float output_max,
  424. uint32_t input_id,
  425. uint32_t filter_id,
  426. uint32_t bias_id,
  427. uint32_t output_id,
  428. uint32_t flags);
  429. /// Define a 2D Deconvolution (Transposed Convolution) Node and add it to a Subgraph.
  430. ///
  431. /// @param subgraph - a Subgraph object that will own the created Node.
  432. /// @param padding_top - implicit padding above 2D output data.
  433. /// @param padding_right - implicit padding to the right of 2D output data.
  434. /// @param padding_bottom - implicit padding below 2D output data.
  435. /// @param padding_left - implicit padding to the left of 2D output data.
  436. /// @param adjustment_height - additional elements in the bottom of the 2D output data.
  437. /// @param adjustment_width - additional elements to the right of the 2D output data.
  438. /// @param kernel_height - kernel (filter) height.
  439. /// @param kernel_width - kernel (filter) width.
  440. /// @param upsampling_height - height of upsampling region for deconvolution input (deconvolution height stride).
  441. /// @param upsampling_width - width of upsampling region for deconvolution input (deconvolution width stride).
  442. /// @param dilation_height - dilation of kernel elements along the height dimension.
  443. /// @param dilation_width - dilation of kernel elements along the width dimension.
  444. /// @param groups - number of convolution groups.
  445. /// @param group_input_channels - number of input channels per group.
  446. /// @param group_output_channels - number of output channels per group.
  447. /// @param output_min - lower bound for clipping output values.
  448. /// @param output_max - upper bound for clipping output values.
  449. /// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph
  450. /// with [N, IH, IW, groups * group_input_channels] dimensions
  451. /// @param filter_id - Value ID for the filter tensor. The filter tensor must be a 4D tensor defined in the @a subgraph
  452. /// with [groups * group_output_channels, kernel_height, kernel_width, group_input_channels]
  453. /// dimensions.
  454. /// @param bias_id - Value ID for the bias tensor, or XNN_INVALID_VALUE_ID for a 2D Deconvolution Node without a bias. If
  455. /// present, the bias tensor must be a 1D tensor defined in the @a subgraph with
  456. /// [groups * group_output_channels] dimensions.
  457. /// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph
  458. /// with [N, OH, OW, groups * group_output_channels] dimensions.
  459. /// @param flags - binary features of the 2D Deconvolution Node. No supported flags are currently defined.
  460. enum xnn_status xnn_define_deconvolution_2d(
  461. xnn_subgraph_t subgraph,
  462. uint32_t padding_top,
  463. uint32_t padding_right,
  464. uint32_t padding_bottom,
  465. uint32_t padding_left,
  466. uint32_t adjustment_height,
  467. uint32_t adjustment_width,
  468. uint32_t kernel_height,
  469. uint32_t kernel_width,
  470. uint32_t upsampling_height,
  471. uint32_t upsampling_width,
  472. uint32_t dilation_height,
  473. uint32_t dilation_width,
  474. uint32_t groups,
  475. size_t group_input_channels,
  476. size_t group_output_channels,
  477. float output_min,
  478. float output_max,
  479. uint32_t input_id,
  480. uint32_t filter_id,
  481. uint32_t bias_id,
  482. uint32_t output_id,
  483. uint32_t flags);
  484. /// Define a 2D Depthwise Convolution Node and add it to a Subgraph.
  485. ///
  486. /// @param subgraph - a Subgraph object that will own the created Node.
  487. /// @param input_padding_top - implicit zero-padding above 2D input data. Must be 0 if XNN_FLAG_TENSORFLOW_SAME_PADDING
  488. /// flag is specified.
  489. /// @param input_padding_right - implicit zero-padding to the right of 2D input data. Must be 0 if
  490. /// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified.
  491. /// @param input_padding_bottom - implicit zero-padding below 2D input data. Must be 0 if
  492. /// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified.
  493. /// @param input_padding_left - implicit zero-padding to the left of 2D input data. Must be 0 if
  494. /// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified.
  495. /// @param kernel_height - kernel (filter) height.
  496. /// @param kernel_width - kernel (filter) width.
  497. /// @param subsampling_height - height of subsampling region for convolution output (convolution height stride).
  498. /// @param subsampling_width - width of subsampling region for convolution output (convolution width stride).
  499. /// @param dilation_height - dilation of kernel elements along the height dimension.
  500. /// @param dilation_width - dilation of kernel elements along the width dimension.
  501. /// @param depth_multiplier - ratio of output channels to input channels.
  502. /// @param input_channels - number of input channels.
  503. /// @param output_min - lower bound for clipping output values.
  504. /// @param output_max - upper bound for clipping output values.
  505. /// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph
  506. /// with [N, IH, IW, input_channels] dimensions
  507. /// @param filter_id - Value ID for the filter tensor. The filter tensor must be a 4D tensor defined in the @a subgraph
  508. /// with [1, kernel_height, kernel_width, input_channels * depth_multiplier] dimensions.
  509. /// @param bias_id - Value ID for the bias tensor, or XNN_INVALID_VALUE_ID for a 2D Depthwise Convolution Node without
  510. /// a bias. If present, the bias tensor must be a 1D tensor defined in the @a subgraph with
  511. /// [input_channels * depth_multiplier] dimensions.
  512. /// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph
  513. /// with [N, OH, OW, input_channels * depth_multiplier] dimensions.
  514. /// @param flags - binary features of the 2D Depthwise Convolution Node. The only currently supported value is
  515. /// XNN_FLAG_TENSORFLOW_SAME_PADDING.
  516. enum xnn_status xnn_define_depthwise_convolution_2d(
  517. xnn_subgraph_t subgraph,
  518. uint32_t input_padding_top,
  519. uint32_t input_padding_right,
  520. uint32_t input_padding_bottom,
  521. uint32_t input_padding_left,
  522. uint32_t kernel_height,
  523. uint32_t kernel_width,
  524. uint32_t subsampling_height,
  525. uint32_t subsampling_width,
  526. uint32_t dilation_height,
  527. uint32_t dilation_width,
  528. uint32_t depth_multiplier,
  529. size_t input_channels,
  530. float output_min,
  531. float output_max,
  532. uint32_t input_id,
  533. uint32_t filter_id,
  534. uint32_t bias_id,
  535. uint32_t output_id,
  536. uint32_t flags);
  537. /// Define a Depth To Space Node 2D and add it to a Subgraph.
  538. ///
  539. /// The Depth To Space 2D Node rearranges data from depth into blocks of spatial data (a reverse transform to
  540. /// Space To Depth). For a given input pixel, an output square of pixels with side @a block_size is formed from values
  541. /// in the corresponding number of its channels. The output depth is therefore @a block_size x @a block_size times
  542. /// smaller than that of the input.
  543. ///
  544. /// @param subgraph - a Subgraph object that will own the created Node.
  545. /// @param block_size - the size of the spatial block.
  546. /// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph
  547. /// with [N, IH, IW, OC * block_size * block_size] dimensions.
  548. /// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph
  549. /// with [N, IH * block_size, IW * block_size, OC] dimensions.
  550. /// @param flags - binary features of the Depth To Space 2D Node. No supported flags are currently defined.
  551. enum xnn_status xnn_define_depth_to_space_2d(
  552. xnn_subgraph_t subgraph,
  553. uint32_t block_size,
  554. uint32_t input_id,
  555. uint32_t output_id,
  556. uint32_t flags);
/// Define a Depth To Space Node and add it to a Subgraph.
///
/// NOTE(review): appears to be an older variant of @ref xnn_define_depth_to_space_2d with @a block_size ordered
/// after @a output_id — confirm they are otherwise equivalent before relying on this assumption.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph.
/// @param block_size - the size of the spatial block.
/// @param flags - binary features of the Node. No supported flags are currently defined.
  557. enum xnn_status xnn_define_depth_to_space(
  558. xnn_subgraph_t subgraph,
  559. uint32_t input_id,
  560. uint32_t output_id,
  561. uint32_t block_size,
  562. uint32_t flags);
  563. /// Define a 1D Global Average Pooling Node and add it to a Subgraph.
  564. ///
  565. /// @param subgraph - a Subgraph object that will own the created Node.
  566. /// @param output_min - lower bound for clipping output values.
  567. /// @param output_max - upper bound for clipping output values.
  568. /// @param input_id - Value ID for the input tensor. The input tensor must be a dense tensor with 2 or more dimensions
  569. /// defined in the @a subgraph. Averaging is performed across the second-innermost dimension.
  570. /// @param output_id - Value ID for the output tensor. The output tensor must be a dense tensor with 2 or more
  571. /// dimensions defined in the @a subgraph.
  572. /// @param flags - binary features of the 1D Global Average Pooling Node. The only currently supported value is
  573. /// XNN_FLAG_REDUCE_DIMS.
  574. enum xnn_status xnn_define_global_average_pooling_1d(
  575. xnn_subgraph_t subgraph,
  576. float output_min,
  577. float output_max,
  578. uint32_t input_id,
  579. uint32_t output_id,
  580. uint32_t flags);
  581. /// Define a 2D Global Average Pooling Node and add it to a Subgraph.
  582. ///
  583. /// @param subgraph - a Subgraph object that will own the created Node.
  584. /// @param output_min - lower bound for clipping output values.
  585. /// @param output_max - upper bound for clipping output values.
  586. /// @param input_id - Value ID for the input tensor. The input tensor must be a dense tensor with 3 or more dimensions
  587. /// defined in the @a subgraph. Averaging is performed across the second- and third-innermost
  588. /// dimensions.
  589. /// @param output_id - Value ID for the output tensor. The output tensor must be a dense tensor with 3 or more
  590. /// dimensions defined in the @a subgraph.
  591. /// @param flags - binary features of the 2D Global Average Pooling Node. The only currently supported value is
  592. /// XNN_FLAG_REDUCE_DIMS.
  593. enum xnn_status xnn_define_global_average_pooling_2d(
  594. xnn_subgraph_t subgraph,
  595. float output_min,
  596. float output_max,
  597. uint32_t input_id,
  598. uint32_t output_id,
  599. uint32_t flags);
  600. /// Define a 1D Global Sum Pooling Node and add it to a Subgraph.
  601. ///
  602. /// @param subgraph - a Subgraph object that will own the created Node.
  603. /// @param output_min - lower bound for clipping output values.
  604. /// @param output_max - upper bound for clipping output values.
  605. /// @param input_id - Value ID for the input tensor. The input tensor must be a dense tensor with 2 or more dimensions
  606. /// defined in the @a subgraph. Summation is performed across the second-innermost dimension.
  607. /// @param output_id - Value ID for the output tensor. The output tensor must be a dense tensor with 2 or more
  608. /// dimensions defined in the @a subgraph.
  609. /// @param flags - binary features of the 1D Global Sum Pooling Node. The only currently supported value is
  610. /// XNN_FLAG_REDUCE_DIMS.
  611. enum xnn_status xnn_define_global_sum_pooling_1d(
  612. xnn_subgraph_t subgraph,
  613. float output_min,
  614. float output_max,
  615. uint32_t input_id,
  616. uint32_t output_id,
  617. uint32_t flags);
  618. /// Define a 2D Global Sum Pooling Node and add it to a Subgraph.
  619. ///
  620. /// @param subgraph - a Subgraph object that will own the created Node.
  621. /// @param output_min - lower bound for clipping output values.
  622. /// @param output_max - upper bound for clipping output values.
  623. /// @param input_id - Value ID for the input tensor. The input tensor must be a dense tensor with 3 or more dimensions
  624. /// defined in the @a subgraph. Summation is performed across the second- and third-innermost
  625. /// dimensions.
  626. /// @param output_id - Value ID for the output tensor. The output tensor must be a dense tensor with 3 or more
  627. /// dimensions defined in the @a subgraph.
  628. /// @param flags - binary features of the 2D Global Sum Pooling Node. The only currently supported value is
  629. /// XNN_FLAG_REDUCE_DIMS.
  630. enum xnn_status xnn_define_global_sum_pooling_2d(
  631. xnn_subgraph_t subgraph,
  632. float output_min,
  633. float output_max,
  634. uint32_t input_id,
  635. uint32_t output_id,
  636. uint32_t flags);
/// Define a 2D Average Pooling Node and add it to a Subgraph.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param input_padding_top - implicit zero-padding above 2D input data. Must be 0 if XNN_FLAG_TENSORFLOW_SAME_PADDING
///   flag is specified.
/// @param input_padding_right - implicit zero-padding to the right of 2D input data. Must be 0 if
///   XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified.
/// @param input_padding_bottom - implicit zero-padding below 2D input data. Must be 0 if
///   XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified.
/// @param input_padding_left - implicit zero-padding to the left of 2D input data. Must be 0 if
///   XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified.
/// @param pooling_height - pooling (kernel) height.
/// @param pooling_width - pooling (kernel) width.
/// @param stride_height - displacing of the pooling window in the vertical dimension of the input pixels corresponding
///   to vertically adjacent output pixels.
/// @param stride_width - displacing of the pooling window in the horizontal dimension of the input pixels corresponding
///   to horizontally adjacent output pixels.
/// @param output_min - lower bound for clipping output values.
/// @param output_max - upper bound for clipping output values.
/// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph
///   with [N, IH, IW, channels] dimensions
/// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph
///   with [N, OH, OW, channels] dimensions.
/// @param flags - binary features of the 2D Average Pooling Node. The only currently supported value is
///   XNN_FLAG_TENSORFLOW_SAME_PADDING.
enum xnn_status xnn_define_average_pooling_2d(
  xnn_subgraph_t subgraph,
  uint32_t input_padding_top,
  uint32_t input_padding_right,
  uint32_t input_padding_bottom,
  uint32_t input_padding_left,
  uint32_t pooling_height,
  uint32_t pooling_width,
  uint32_t stride_height,
  uint32_t stride_width,
  float output_min,
  float output_max,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags);
/// Define a Fully Connected Node and add it to a Subgraph.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param output_min - lower bound for clipping output values.
/// @param output_max - upper bound for clipping output values.
/// @param input_id - Value ID for the input tensor. The input tensor must be an N-dimensional tensor defined in the
///   @a subgraph. If XNN_FLAG_TENSORFLOW_RESHAPE_2D is not specified, the input tensor must be at least
///   1D and its last dimension must match the last dimension of the filter tensor. In particular, if
///   input is a 2D tensor, it must have [batch_size, input_channels] dimensions.
///   If XNN_FLAG_TENSORFLOW_RESHAPE_2D is specified, the number of elements in the input tensor must be
///   divisible by the input_channels. The tensor will be first flattened into a 1D tensor of
///   [num_input_elements] dimensions, then reshaped into a 2D tensor of
///   [num_input_elements / input_channels, input_channels] dimensions where num_input_elements is the
///   total number of elements in the input tensor.
/// @param filter_id - Value ID for the filter tensor. The filter tensor must be a 2D tensor defined in the @a subgraph.
///   If the XNN_FLAG_TRANSPOSE_WEIGHTS flag is not specified, the filter tensor must have
///   [output_channels, input_channels] dimensions. If the XNN_FLAG_TRANSPOSE_WEIGHTS flag is
///   specified, the filter tensor must have [input_channels, output_channels] dimensions.
/// @param bias_id - Value ID for the bias tensor, or XNN_INVALID_VALUE_ID for a Fully Connected Node without a bias.
///   If present, the bias tensor must be a 1D tensor defined in the @a subgraph with [output_channels]
///   dimensions.
/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph.
///   If XNN_FLAG_TENSORFLOW_RESHAPE_2D is not specified, the output tensor must have the same
///   dimensionality as the input tensor, all its dimensions but the last one must match the
///   corresponding dimensions of the input tensor, and the last dimension of the output tensor must
///   match the first dimension of the filter tensor. In particular, if input is a 2D tensor, output
///   must be a 2D tensor of [batch_size, output_channels] dimensions.
///   If XNN_FLAG_TENSORFLOW_RESHAPE_2D is specified, output must be a 2D tensor of
///   [num_input_elements / input_channels, output_channels] dimensions where num_input_elements is the
///   total number of elements in the input tensor.
/// @param flags - binary features of the Fully Connected Node. The only currently supported values are
///   XNN_FLAG_TENSORFLOW_RESHAPE_2D and XNN_FLAG_TRANSPOSE_WEIGHTS.
enum xnn_status xnn_define_fully_connected(
  xnn_subgraph_t subgraph,
  float output_min,
  float output_max,
  uint32_t input_id,
  uint32_t filter_id,
  uint32_t bias_id,
  uint32_t output_id,
  uint32_t flags);
/// Define a Sparse Fully Connected Node and add it to a Subgraph.
///
/// This operator is experimental, and will be removed in the future.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param output_min - lower bound for clipping output values.
/// @param output_max - upper bound for clipping output values.
/// @param input_id - Value ID for the input tensor. The input tensor must be an N-dimensional tensor defined in the
///   @a subgraph. If XNN_FLAG_TENSORFLOW_RESHAPE_2D is not specified, the input tensor must be at least
///   1D and its last dimension must match the last dimension of the filter tensor. In particular, if
///   input is a 2D tensor, it must have [batch_size, input_channels] dimensions.
///   If XNN_FLAG_TENSORFLOW_RESHAPE_2D is specified, the number of elements in the input tensor must be
///   divisible by the input_channels. The tensor will be first flattened into a 1D tensor of
///   [num_input_elements] dimensions, then reshaped into a 2D tensor of
///   [num_input_elements / input_channels, input_channels] dimensions where num_input_elements is the
///   total number of elements in the input tensor.
/// @param filter_id - Value ID for the filter tensor. The filter tensor must be a 2D tensor defined in the @a subgraph.
///   If the XNN_FLAG_TRANSPOSE_WEIGHTS flag is not specified, the filter tensor must have
///   [output_channels, input_channels] dimensions. If the XNN_FLAG_TRANSPOSE_WEIGHTS flag is
///   specified, the filter tensor must have [input_channels, output_channels] dimensions.
/// @param bias_id - Value ID for the bias tensor, or XNN_INVALID_VALUE_ID for a Sparse Fully Connected Node without a
///   bias. If present, the bias tensor must be a 1D tensor defined in the @a subgraph with
///   [output_channels] dimensions.
/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph.
///   If XNN_FLAG_TENSORFLOW_RESHAPE_2D is not specified, the output tensor must have the same
///   dimensionality as the input tensor, all its dimensions but the last one must match the
///   corresponding dimensions of the input tensor, and the last dimension of the output tensor must
///   match the first dimension of the filter tensor. In particular, if input is a 2D tensor, output
///   must be a 2D tensor of [batch_size, output_channels] dimensions.
///   If XNN_FLAG_TENSORFLOW_RESHAPE_2D is specified, output must be a 2D tensor of
///   [num_input_elements / input_channels, output_channels] dimensions where num_input_elements is the
///   total number of elements in the input tensor.
/// @param flags - binary features of the Sparse Fully Connected Node. The only currently supported values are
///   XNN_FLAG_TENSORFLOW_RESHAPE_2D and XNN_FLAG_TRANSPOSE_WEIGHTS.
enum xnn_status xnn_define_fully_connected_sparse(
  xnn_subgraph_t subgraph,
  float output_min,
  float output_max,
  uint32_t input_id,
  uint32_t filter_id,
  uint32_t bias_id,
  uint32_t output_id,
  uint32_t flags);
/// Define a 2D Max Pooling Node and add it to a Subgraph.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param input_padding_top - implicit zero-padding above 2D input data. Must be 0 if XNN_FLAG_TENSORFLOW_SAME_PADDING
///   flag is specified.
/// @param input_padding_right - implicit zero-padding to the right of 2D input data. Must be 0 if
///   XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified.
/// @param input_padding_bottom - implicit zero-padding below 2D input data. Must be 0 if
///   XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified.
/// @param input_padding_left - implicit zero-padding to the left of 2D input data. Must be 0 if
///   XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified.
/// @param pooling_height - pooling (kernel) height.
/// @param pooling_width - pooling (kernel) width.
/// @param stride_height - displacing of the pooling window in the vertical dimension of the input pixels corresponding
///   to vertically adjacent output pixels.
/// @param stride_width - displacing of the pooling window in the horizontal dimension of the input pixels corresponding
///   to horizontally adjacent output pixels.
/// @param dilation_height - dilation of pooling elements along the height dimension.
/// @param dilation_width - dilation of pooling elements along the width dimension.
/// @param output_min - lower bound for clipping output values.
/// @param output_max - upper bound for clipping output values.
/// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph
///   with [N, IH, IW, channels] dimensions
/// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph
///   with [N, OH, OW, channels] dimensions.
/// @param flags - binary features of the 2D Max Pooling Node. The only currently supported value is
///   XNN_FLAG_TENSORFLOW_SAME_PADDING.
enum xnn_status xnn_define_max_pooling_2d(
  xnn_subgraph_t subgraph,
  uint32_t input_padding_top,
  uint32_t input_padding_right,
  uint32_t input_padding_bottom,
  uint32_t input_padding_left,
  uint32_t pooling_height,
  uint32_t pooling_width,
  uint32_t stride_height,
  uint32_t stride_width,
  uint32_t dilation_height,
  uint32_t dilation_width,
  float output_min,
  float output_max,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags);
/// Define a 2D ArgMax Pooling Node and add it to a Subgraph.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param input_padding_top - implicit zero-padding above 2D input data.
/// @param input_padding_right - implicit zero-padding to the right of 2D input data.
/// @param input_padding_bottom - implicit zero-padding below 2D input data.
/// @param input_padding_left - implicit zero-padding to the left of 2D input data.
/// @param pooling_height - pooling (kernel) height. Vertical stride between pooling regions matches this value.
/// @param pooling_width - pooling (kernel) width. Horizontal stride between pooling regions matches this value.
/// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph
///   with [N, IH, IW, channels] dimensions
/// @param output_value_id - Value ID for the output tensor with the maximum values in the pools. The output tensor must
///   be a 4D tensor defined in the @a subgraph with [N, OH, OW, channels] dimensions.
/// @param output_index_id - Value ID for the output tensor with the indexes of the maximum values in the pools. The
///   output tensor must be a 4D tensor defined in the @a subgraph with [N, OH, OW, channels]
///   dimensions.
/// @param flags - binary features of the 2D ArgMax Pooling Node. No supported flags are currently defined.
enum xnn_status xnn_define_argmax_pooling_2d(
  xnn_subgraph_t subgraph,
  uint32_t input_padding_top,
  uint32_t input_padding_right,
  uint32_t input_padding_bottom,
  uint32_t input_padding_left,
  uint32_t pooling_height,
  uint32_t pooling_width,
  uint32_t input_id,
  uint32_t output_value_id,
  uint32_t output_index_id,
  uint32_t flags);
/// Define a 2D UnPooling Node and add it to a Subgraph.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param padding_top - implicit padding above 2D output data.
/// @param padding_right - implicit padding to the right of 2D output data.
/// @param padding_bottom - implicit padding below 2D output data.
/// @param padding_left - implicit padding to the left of 2D output data.
/// @param pooling_height - height of the pooling window.
/// @param pooling_width - width of the pooling window.
/// @param input_value_id - Value ID for the input tensor with the max-pooling values to invert. The input value tensor
///   must be a 4D tensor defined in the @a subgraph with [N, IH, IW, channels] dimensions.
/// @param input_index_id - Value ID for the input tensor with the indices of the per-pool maximum values produced by
///   a 2D ArgMax Pooling Node. The input tensor must be a 4D tensor defined in the @a subgraph with
///   [N, IH, IW, channels] dimensions.
/// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph
///   with [N, OH, OW, channels] dimensions.
/// @param flags - binary features of the 2D UnPooling Node. No supported flags are currently defined.
enum xnn_status xnn_define_unpooling_2d(
  xnn_subgraph_t subgraph,
  uint32_t padding_top,
  uint32_t padding_right,
  uint32_t padding_bottom,
  uint32_t padding_left,
  uint32_t pooling_height,
  uint32_t pooling_width,
  uint32_t input_value_id,
  uint32_t input_index_id,
  uint32_t output_id,
  uint32_t flags);
/// Define a 2-Input Add Node and add it to a Subgraph.
///
/// The 2-Input Add Node computes elementwise addition of two tensor inputs with numpy broadcasting rules.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param output_min - lower bound for clipping output values.
/// @param output_max - upper bound for clipping output values.
/// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
///   the @a subgraph with each dimension either equal to the corresponding dimension of the second
///   input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along
///   that dimension.
/// @param input2_id - Value ID for the second input tensor. The input tensor must be an M-dimensional tensor defined in
///   the @a subgraph with each dimension either equal to the corresponding dimension of the first
///   input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along
///   that dimension.
/// @param output_id - Value ID for the output tensor. The output tensor must be a max(N,M)-dimensional tensor defined
///   in the @a subgraph with each dimension equal to the maximum between the corresponding dimension
///   of the two inputs.
/// @param flags - binary features of the Add Node. No supported flags are currently defined.
enum xnn_status xnn_define_add2(
  xnn_subgraph_t subgraph,
  float output_min,
  float output_max,
  uint32_t input1_id,
  uint32_t input2_id,
  uint32_t output_id,
  uint32_t flags);
/// Define a 2-Input Multiply Node and add it to a Subgraph.
///
/// The 2-Input Multiply Node computes elementwise multiplication of two tensor inputs with numpy broadcasting rules.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param output_min - lower bound for clipping output values.
/// @param output_max - upper bound for clipping output values.
/// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
///   the @a subgraph with each dimension either equal to the corresponding dimension of the second
///   input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along
///   that dimension.
/// @param input2_id - Value ID for the second input tensor. The input tensor must be an M-dimensional tensor defined in
///   the @a subgraph with each dimension either equal to the corresponding dimension of the first
///   input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along
///   that dimension.
/// @param output_id - Value ID for the output tensor. The output tensor must be a max(N,M)-dimensional tensor defined
///   in the @a subgraph with each dimension equal to the maximum between the corresponding dimension
///   of the two inputs.
/// @param flags - binary features of the Multiply Node. No supported flags are currently defined.
enum xnn_status xnn_define_multiply2(
  xnn_subgraph_t subgraph,
  float output_min,
  float output_max,
  uint32_t input1_id,
  uint32_t input2_id,
  uint32_t output_id,
  uint32_t flags);
/// Cap operations applied to logits (Q * K) of attention operator.
enum xnn_attention_logits_cap_type {
  /// No capping.
  xnn_attention_logits_cap_type_none = 0,
  /// Cap the absolute values of logits by tanh: tanh(logits / cap) * cap
  xnn_attention_logits_cap_type_tanh
};
/// Params when the cap type is xnn_attention_logits_cap_type_tanh.
struct xnn_attention_logits_cap_tanh_params {
  /// Cap value; logits are transformed as tanh(logits / cap) * cap.
  float cap;
};
/// Define a Scaled Dot-Product Attention Node and add it to a Subgraph.
///
/// This operator is experimental.
///
/// The Scaled Dot-Product Attention Node computes a multi-head or multi-query scaled dot attention on the query, key,
/// and value tensors.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param cap_type - type of cap to be applied to the logits.
/// @param cap_params - parameters for the cap. Must be a pointer to xnn_attention_logits_cap_tanh_params if cap_type
///   is xnn_attention_logits_cap_type_tanh.
/// @param query_id - Value ID for the query tensor. The query tensor must be a 3+-dimensional tensor defined in the
///   @a subgraph with the dimensions as [*, H, T, C], where H/T/C are the heads/tokens/channels, and *
///   is the 0 or more dimensions treated as batch size.
/// @param key_id - Value ID for the key tensor. The key tensor must be a 2+-dimensional tensor defined in the
///   @a subgraph. It can have the same number of dimensions as the query, with the dimensions as
///   [*, H, U, C] (multi-head), or have 1 less dimension than the query, with the dimensions
///   as [*, U, C] (multi-query, number of heads omitted implies single head), where H/U/C are the
///   heads/key_value_tokens/channels, and * is the 0 or more dimensions treated as batch size. These
///   batch size dimensions must be the same as query.
/// @param value_id - Value ID for the value tensor. The value tensor must be a 2+-dimensional tensor defined in the
///   @a subgraph. It can have the same number of dimensions as the query, with the dimensions as
///   [*, H, U, D] (multi-head), or have 1 less dimension than the query, with the dimensions
///   as [*, U, D] (multi-query, number of heads omitted implies single head), where H/U/D are the
///   heads/key_value_tokens/value_channels, and * is the 0 or more dimensions treated as batch size.
///   These batch size dimensions must be the same as query and key.
/// @param scale_id - Value ID for the scale tensor. The scale tensor must be a 1D tensor defined in the @a subgraph
///   with [C] dimensions. The query tensor is multiplied with this scale tensor before the dot product
///   with the key tensor.
/// @param mask_id - Value ID for the mask tensor. The mask tensor must be a 2D tensor defined in the @a subgraph with
///   [T, U] dimensions. The mask tensor is added to the logits (query dot key).
/// @param output_id - Value ID for the output tensor. The output tensor must be a 3+-dimensional tensor defined in the
///   @a subgraph with the dimensions as [*, H, T, D], where H/T/D are the heads/tokens/value_channels,
///   and * is the 0 or more dimensions treated as batch size. These batch size dimensions must be the
///   same as query, key, and value.
/// @param flags - binary features of the Scaled Dot Product Attention Node. No supported flags are currently defined.
enum xnn_status xnn_define_scaled_dot_product_attention(
  xnn_subgraph_t subgraph,
  enum xnn_attention_logits_cap_type cap_type,
  const void* cap_params,
  uint32_t query_id,
  uint32_t key_id,
  uint32_t value_id,
  uint32_t scale_id,
  uint32_t mask_id,
  uint32_t output_id,
  uint32_t flags);
/// Define a Subtract Node and add it to a Subgraph.
///
/// The Subtract Node computes elementwise subtraction of two tensor inputs with numpy broadcasting rules.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param output_min - lower bound for clipping output values.
/// @param output_max - upper bound for clipping output values.
/// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
///   the @a subgraph with each dimension either equal to the corresponding dimension of the second
///   input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along
///   that dimension.
/// @param input2_id - Value ID for the second input tensor. The input tensor must be an M-dimensional tensor defined in
///   the @a subgraph with each dimension either equal to the corresponding dimension of the first
///   input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along
///   that dimension.
/// @param output_id - Value ID for the output tensor. The output tensor must be a max(N,M)-dimensional tensor defined
///   in the @a subgraph with each dimension equal to the maximum between the corresponding dimension
///   of the two inputs.
/// @param flags - binary features of the Subtract Node. No supported flags are currently defined.
enum xnn_status xnn_define_subtract(
  xnn_subgraph_t subgraph,
  float output_min,
  float output_max,
  uint32_t input1_id,
  uint32_t input2_id,
  uint32_t output_id,
  uint32_t flags);
/// Define a Divide Node and add it to a Subgraph.
///
/// The Divide Node computes elementwise division of two tensor inputs with numpy broadcasting rules.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param output_min - lower bound for clipping output values.
/// @param output_max - upper bound for clipping output values.
/// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
///   the @a subgraph with each dimension either equal to the corresponding dimension of the second
///   input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along
///   that dimension.
/// @param input2_id - Value ID for the second input tensor. The input tensor must be an M-dimensional tensor defined in
///   the @a subgraph with each dimension either equal to the corresponding dimension of the first
///   input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along
///   that dimension.
/// @param output_id - Value ID for the output tensor. The output tensor must be a max(N,M)-dimensional tensor defined
///   in the @a subgraph with each dimension equal to the maximum between the corresponding dimension
///   of the two inputs.
/// @param flags - binary features of the Divide Node. No supported flags are currently defined.
enum xnn_status xnn_define_divide(
  xnn_subgraph_t subgraph,
  float output_min,
  float output_max,
  uint32_t input1_id,
  uint32_t input2_id,
  uint32_t output_id,
  uint32_t flags);
/// Define a 2-Input Maximum Node and add it to a Subgraph.
///
/// The 2-Input Maximum Node computes elementwise maximum of two tensor inputs with numpy broadcasting rules.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
///   the @a subgraph with each dimension either equal to the corresponding dimension of the second
///   input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along
///   that dimension.
/// @param input2_id - Value ID for the second input tensor. The input tensor must be an M-dimensional tensor defined in
///   the @a subgraph with each dimension either equal to the corresponding dimension of the first
///   input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along
///   that dimension.
/// @param output_id - Value ID for the output tensor. The output tensor must be a max(N,M)-dimensional tensor defined
///   in the @a subgraph with each dimension equal to the maximum between the corresponding dimension
///   of the two inputs.
/// @param flags - binary features of the Maximum Node. No supported flags are currently defined.
enum xnn_status xnn_define_maximum2(
  xnn_subgraph_t subgraph,
  uint32_t input1_id,
  uint32_t input2_id,
  uint32_t output_id,
  uint32_t flags);
/// Define a 2-Input Minimum Node and add it to a Subgraph.
///
/// The 2-Input Minimum Node computes elementwise minimum of two tensor inputs with numpy broadcasting rules.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
///   the @a subgraph with each dimension either equal to the corresponding dimension of the second
///   input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along
///   that dimension.
/// @param input2_id - Value ID for the second input tensor. The input tensor must be an M-dimensional tensor defined in
///   the @a subgraph with each dimension either equal to the corresponding dimension of the first
///   input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along
///   that dimension.
/// @param output_id - Value ID for the output tensor. The output tensor must be a max(N,M)-dimensional tensor defined
///   in the @a subgraph with each dimension equal to the maximum between the corresponding dimension
///   of the two inputs.
/// @param flags - binary features of the Minimum Node. No supported flags are currently defined.
enum xnn_status xnn_define_minimum2(
  xnn_subgraph_t subgraph,
  uint32_t input1_id,
  uint32_t input2_id,
  uint32_t output_id,
  uint32_t flags);
/// Define a Squared Difference Node and add it to a Subgraph.
///
/// The Squared Difference Node computes elementwise squared difference of two tensor inputs with numpy broadcasting
/// rules.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
///   the @a subgraph with each dimension either equal to the corresponding dimension of the second
///   input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along
///   that dimension.
/// @param input2_id - Value ID for the second input tensor. The input tensor must be an M-dimensional tensor defined in
///   the @a subgraph with each dimension either equal to the corresponding dimension of the first
///   input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along
///   that dimension.
/// @param output_id - Value ID for the output tensor. The output tensor must be a max(N,M)-dimensional tensor defined
///   in the @a subgraph with each dimension equal to the maximum between the corresponding dimension
///   of the two inputs.
/// @param flags - binary features of the Squared Difference Node. No supported flags are currently defined.
enum xnn_status xnn_define_squared_difference(
  xnn_subgraph_t subgraph,
  uint32_t input1_id,
  uint32_t input2_id,
  uint32_t output_id,
  uint32_t flags);
/// Define a Constant Pad Node with static padding specification and add it to a Subgraph.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param pre_paddings - number of padding elements to insert before input elements for every dimension. This array
///   must have as many elements as the number of dimensions in the input tensor.
/// @param post_paddings - number of padding elements to insert after input elements for every dimension. This array
///   must have as many elements as the number of dimensions in the input tensor.
/// @param padding_value - constant value used to initialize padding elements.
/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
///   shape must match the shape of the input tensor with padding.
/// @param flags - binary features of the Constant Pad Node. No supported flags are currently defined.
enum xnn_status xnn_define_static_constant_pad(
  xnn_subgraph_t subgraph,
  const size_t* pre_paddings,
  const size_t* post_paddings,
  float padding_value,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags);
/// Define a Mean Node and add it to a Subgraph.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param num_reduction_axes - number of axes along which mean is computed.
/// @param reduction_axes - axes along which mean is computed.
/// @param input_id - Value ID for the input tensor. The input tensor must be a dense tensor with at least
///   @a num_reduction_axes dimensions defined in the @a subgraph.
/// @param output_id - Value ID for the output tensor. The output tensor must be a dense tensor defined in the
///   @a subgraph with @a num_reduction_axes fewer dimensions than the input tensor (if
///   XNN_FLAG_REDUCE_DIMS is specified), or has same dimension rank but the dimension at
///   @a reduction_axes reduced to 1 (if XNN_FLAG_REDUCE_DIMS is not specified).
/// @param flags - binary features of the Mean Node. The only currently supported value is XNN_FLAG_REDUCE_DIMS.
enum xnn_status xnn_define_static_mean(
  xnn_subgraph_t subgraph,
  size_t num_reduction_axes,
  const size_t* reduction_axes,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags);
/// Define a 2-Input Concatenate Node and add it to a Subgraph.
///
/// The 2-Input Concatenate Node concatenates two tensors along a specified axis.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param axis - the axis to concatenate the two input tensors along.
/// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
///   the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
///   second input.
/// @param input2_id - Value ID for the second input tensor. The input tensor must be an N-dimensional tensor defined in
///   the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
///   first input.
/// @param output_id - Value ID for the output tensor. The output tensor must be an N-dimensional tensor defined
///   in the @a subgraph with each dimension equal to the dimension of both inputs, except the axis
///   dimension, where it is the sum of the corresponding dimensions of both inputs.
/// @param flags - binary features of the Concatenate Node. No supported flags are currently defined.
enum xnn_status xnn_define_concatenate2(
  xnn_subgraph_t subgraph,
  size_t axis,
  uint32_t input1_id,
  uint32_t input2_id,
  uint32_t output_id,
  uint32_t flags);
/// Define a 3-Input Concatenate Node and add it to a Subgraph.
///
/// The 3-Input Concatenate Node concatenates three tensors along a specified axis.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param axis - the axis to concatenate the three input tensors along
/// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
///                    the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
///                    other inputs.
/// @param input2_id - Value ID for the second input tensor. The input tensor must be an N-dimensional tensor defined in
///                    the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
///                    other inputs.
/// @param input3_id - Value ID for the third input tensor. The input tensor must be an N-dimensional tensor defined in
///                    the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
///                    other inputs.
/// @param output_id - Value ID for the output tensor. The output tensor must be an N-dimensional tensor defined
///                    in the @a subgraph with each dimension equal to the dimension of all inputs, except the axis
///                    dimension, where it is the sum of the corresponding dimensions of all inputs.
/// @param flags - binary features of the Concatenate Node. No supported flags are currently defined.
enum xnn_status xnn_define_concatenate3(
  xnn_subgraph_t subgraph,
  size_t axis,
  uint32_t input1_id,
  uint32_t input2_id,
  uint32_t input3_id,
  uint32_t output_id,
  uint32_t flags);
/// Define a 4-Input Concatenate Node and add it to a Subgraph.
///
/// The 4-Input Concatenate Node concatenates four tensors along a specified axis.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param axis - the axis to concatenate the four input tensors along
/// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
///                    the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
///                    other inputs.
/// @param input2_id - Value ID for the second input tensor. The input tensor must be an N-dimensional tensor defined in
///                    the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
///                    other inputs.
/// @param input3_id - Value ID for the third input tensor. The input tensor must be an N-dimensional tensor defined in
///                    the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
///                    other inputs.
/// @param input4_id - Value ID for the fourth input tensor. The input tensor must be an N-dimensional tensor defined in
///                    the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
///                    other inputs.
/// @param output_id - Value ID for the output tensor. The output tensor must be an N-dimensional tensor defined
///                    in the @a subgraph with each dimension equal to the dimension of all inputs, except the axis
///                    dimension, where it is the sum of the corresponding dimensions of all inputs.
/// @param flags - binary features of the Concatenate Node. No supported flags are currently defined.
enum xnn_status xnn_define_concatenate4(
  xnn_subgraph_t subgraph,
  size_t axis,
  uint32_t input1_id,
  uint32_t input2_id,
  uint32_t input3_id,
  uint32_t input4_id,
  uint32_t output_id,
  uint32_t flags);
/// Define a 5-Input Concatenate Node and add it to a Subgraph.
///
/// The 5-Input Concatenate Node concatenates five tensors along a specified axis.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param axis - the axis to concatenate the five input tensors along
/// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
///                    the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
///                    other inputs.
/// @param input2_id - Value ID for the second input tensor. The input tensor must be an N-dimensional tensor defined in
///                    the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
///                    other inputs.
/// @param input3_id - Value ID for the third input tensor. The input tensor must be an N-dimensional tensor defined in
///                    the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
///                    other inputs.
/// @param input4_id - Value ID for the fourth input tensor. The input tensor must be an N-dimensional tensor defined in
///                    the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
///                    other inputs.
/// @param input5_id - Value ID for the fifth input tensor. The input tensor must be an N-dimensional tensor defined in
///                    the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
///                    other inputs.
/// @param output_id - Value ID for the output tensor. The output tensor must be an N-dimensional tensor defined
///                    in the @a subgraph with each dimension equal to the dimension of all inputs, except the axis
///                    dimension, where it is the sum of the corresponding dimensions of all inputs.
/// @param flags - binary features of the Concatenate Node. No supported flags are currently defined.
enum xnn_status xnn_define_concatenate5(
  xnn_subgraph_t subgraph,
  size_t axis,
  uint32_t input1_id,
  uint32_t input2_id,
  uint32_t input3_id,
  uint32_t input4_id,
  uint32_t input5_id,
  uint32_t output_id,
  uint32_t flags);
/// Define a Copy Node and add it to a Subgraph.
///
/// The Copy Node copies an input tensor to an output tensor.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
///                    shape must match the shape of the input tensor.
/// @param flags - binary features of the Copy Node. No supported flags are currently defined.
enum xnn_status xnn_define_copy(
  xnn_subgraph_t subgraph,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags);
/// Define a 2-Output Split Node and add it to a Subgraph.
///
/// The 2-Output Split Node splits an input tensor into two output tensors along a specified axis evenly.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param split_dim - the dimension to split the input tensor along
/// @param input_id - Value ID for the input tensor. The input tensor must be an N-dimensional tensor defined in the @a
///                   subgraph.
/// @param output1_id - Value ID for the first output tensor. The output tensor must be an N-dimensional tensor defined
///                     in the @a subgraph with each dimension, except the split dimension, equal to the corresponding
///                     dimension of the second output. The split_dim dimension is half of the input's split_dim.
/// @param output2_id - Value ID for the second output tensor. The output tensor must be an N-dimensional tensor
///                     defined in the @a subgraph with each dimension, except the split dimension, equal to the
///                     corresponding dimension of the first output. The split_dim dimension is half of the input's
///                     split_dim.
/// @param flags - binary features of the Split Node. No supported flags are currently defined.
enum xnn_status xnn_define_even_split2(
  xnn_subgraph_t subgraph,
  size_t split_dim,
  uint32_t input_id,
  uint32_t output1_id,
  uint32_t output2_id,
  uint32_t flags);
/// Define a 3-Output Split Node and add it to a Subgraph.
///
/// The 3-Output Split Node splits an input tensor into three output tensors along a specified axis evenly.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param split_dim - the dimension to split the input tensor along
/// @param input_id - Value ID for the input tensor. The input tensor must be an N-dimensional tensor defined in the @a
///                   subgraph.
/// @param output1_id - Value ID for the first output tensor. The output tensor must be an N-dimensional tensor defined
///                     in the @a subgraph with each dimension, except the split dimension, equal to the corresponding
///                     dimension of the second and third output. The split_dim dimension is one third of the input's
///                     split_dim.
/// @param output2_id - Value ID for the second output tensor. The output tensor must be an N-dimensional tensor
///                     defined in the @a subgraph with each dimension, except the split dimension, equal to the
///                     corresponding dimension of the first and third output. The split_dim dimension is one third of
///                     the input's split_dim.
/// @param output3_id - Value ID for the third output tensor. The output tensor must be an N-dimensional tensor
///                     defined in the @a subgraph with each dimension, except the split dimension, equal to the
///                     corresponding dimension of the first and second output. The split_dim dimension is one third of
///                     the input's split_dim.
/// @param flags - binary features of the Split Node. No supported flags are currently defined.
enum xnn_status xnn_define_even_split3(
  xnn_subgraph_t subgraph,
  size_t split_dim,
  uint32_t input_id,
  uint32_t output1_id,
  uint32_t output2_id,
  uint32_t output3_id,
  uint32_t flags);
/// Define a 4-Output Split Node and add it to a Subgraph.
///
/// The 4-Output Split Node splits an input tensor into four output tensors along a specified axis evenly.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param split_dim - the dimension to split the input tensor along
/// @param input_id - Value ID for the input tensor. The input tensor must be an N-dimensional tensor defined in the @a
///                   subgraph.
/// @param output1_id - Value ID for the first output tensor. The output tensor must be an N-dimensional tensor defined
///                     in the @a subgraph with each dimension, except the split dimension, equal to the corresponding
///                     dimension of the other output tensors. The split_dim dimension is one fourth of the input's
///                     split_dim.
/// @param output2_id - Value ID for the second output tensor. The output tensor must be an N-dimensional tensor
///                     defined in the @a subgraph with each dimension, except the split dimension, equal to the
///                     corresponding dimension of the other output tensors. The split_dim dimension is one fourth of
///                     the input's split_dim.
/// @param output3_id - Value ID for the third output tensor. The output tensor must be an N-dimensional tensor
///                     defined in the @a subgraph with each dimension, except the split dimension, equal to the
///                     corresponding dimension of the other output tensors. The split_dim dimension is one fourth of
///                     the input's split_dim.
/// @param output4_id - Value ID for the fourth output tensor. The output tensor must be an N-dimensional tensor
///                     defined in the @a subgraph with each dimension, except the split dimension, equal to the
///                     corresponding dimension of the other output tensors. The split_dim dimension is one fourth of
///                     the input's split_dim.
/// @param flags - binary features of the Split Node. No supported flags are currently defined.
enum xnn_status xnn_define_even_split4(
  xnn_subgraph_t subgraph,
  size_t split_dim,
  uint32_t input_id,
  uint32_t output1_id,
  uint32_t output2_id,
  uint32_t output3_id,
  uint32_t output4_id,
  uint32_t flags);
/// Define a Reshape Node with static shape specification and add it to a Subgraph.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param num_dims - number of shape dimensions in the output tensor.
/// @param new_shape - shape dimensions of the output tensor. This array must have @a num_dims elements.
/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
///                    shape must match @a new_shape. The total number of elements in the output tensor must match the
///                    total number of elements in the input tensor.
/// @param flags - binary features of the Reshape Node. No supported flags are currently defined.
enum xnn_status xnn_define_static_reshape(
  xnn_subgraph_t subgraph,
  size_t num_dims,
  const size_t* new_shape,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags);
/// Define a Node that reshapes a tensor to two dimensions, retaining the
/// trailing dimension, and add it to a Subgraph.
///
/// This operator is experimental.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param input_id - Value ID for the input tensor. The input tensor must be
///                   defined in the @a subgraph.
/// @param output_id - Value ID for the output tensor. The output tensor must be
///                    defined in the @a subgraph, and its total number of
///                    elements must match that of the input tensor.
/// @param flags - binary features of the Reshape Node. No supported flags are
///                currently defined.
enum xnn_status xnn_define_reshape_2d(xnn_subgraph_t subgraph,
                                      uint32_t input_id, uint32_t output_id,
                                      uint32_t flags);
/// Define a 2D Resize Bilinear Node with static output height and width specification and add it to a Subgraph.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param new_height - height dimension of the output tensor.
/// @param new_width - width dimension of the output tensor.
/// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph
///                   with [N, H, W, C] dimensions.
/// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph
///                    with [N, new_height, new_width, C] dimensions.
/// @param flags - binary features of the 2D Resize Bilinear Node. The only currently supported values are
///                XNN_FLAG_TENSORFLOW_LEGACY_MODE and XNN_FLAG_ALIGN_CORNERS, which are mutually exclusive.
enum xnn_status xnn_define_static_resize_bilinear_2d(
  xnn_subgraph_t subgraph,
  size_t new_height,
  size_t new_width,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags);
/// Define a PReLU (Parametric ReLU) Node and add it to a Subgraph.
///
/// The PReLU Node passes positive input elements unchanged and multiplies negative input elements by the
/// corresponding (per-channel) element of the slope tensor.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph
///                   with [N, H, W, channels] dimensions.
/// @param slope_id - Value ID for the slope tensor. The slope tensor must be a 1D tensor defined in the @a subgraph with
///                   [channels] dimensions.
/// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph
///                    with [N, H, W, channels] dimensions.
/// @param flags - binary features of the PReLU Node. No supported flags are currently defined.
enum xnn_status xnn_define_prelu(
  xnn_subgraph_t subgraph,
  uint32_t input_id,
  uint32_t slope_id,
  uint32_t output_id,
  uint32_t flags);
/// Define a RoPE (Rotary Positional Embeddings) Node and add it to a Subgraph.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param max_sequence_size - maximum possible number of tokens (maximum sequence length) of the input/output tensors.
/// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph
///                   with [batch, tokens, heads, channels] dimensions.
/// @param weights_id - Value ID for the weights tensor. The weights tensor must be a 2D tensor defined in the
///                     @a subgraph with [max_sequence_size, channels] dimensions.
/// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph
///                    with [batch, tokens, heads, channels] dimensions.
/// @param flags - binary features of the RoPE Node. No supported flags are currently defined.
enum xnn_status xnn_define_rope(
  xnn_subgraph_t subgraph,
  size_t max_sequence_size,
  uint32_t input_id,
  uint32_t weights_id,
  uint32_t output_id,
  uint32_t flags);
/// Define an Abs Node and add it to a Subgraph.
///
/// The Abs Node computes the element-wise absolute value of the input tensor.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
///                    shape must match the shape of the input tensor.
/// @param flags - binary features of the Abs Node. No supported flags are currently defined.
enum xnn_status xnn_define_abs(
  xnn_subgraph_t subgraph,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags);
/// Define a Bankers' Rounding Node and add it to a Subgraph.
///
/// Bankers' Rounding rounds each element to the nearest integer, with ties rounded to the nearest even integer
/// (round-half-to-even).
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
///                    shape must match the shape of the input tensor.
/// @param flags - binary features of the Bankers' Rounding Node. No supported flags are currently defined.
enum xnn_status xnn_define_bankers_rounding(
  xnn_subgraph_t subgraph,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags);
/// Define a Batch Matrix Multiply Node and add it to a Subgraph.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
///                    the @a subgraph. It must be at least 3D. The first N-2 dimensions must match the second input
///                    tensor. The last 2 dimensions are [M, K]. If XNN_FLAG_TRANSPOSE_B is not specified, the last
///                    dimension must match the second last dimension of the second input tensor. If
///                    XNN_FLAG_TRANSPOSE_B is specified, the last dimension must match the last dimension of the
///                    second input tensor.
/// @param input2_id - Value ID for the second input tensor. The input tensor must be an N-dimensional tensor defined
///                    in the @a subgraph. It must be at least 3D. The first N-2 dimensions must match the first input
///                    tensor. If XNN_FLAG_TRANSPOSE_B is not specified, the last 2 dimensions are [K, N], and the
///                    second last dimension must match the last dimension of the first input tensor. If
///                    XNN_FLAG_TRANSPOSE_B is specified, the last 2 dimensions are [N, K], and the last dimension must
///                    match the last dimension of the first input tensor.
/// @param output_id - Value ID for the output tensor. The output tensor must be an N-dimensional tensor defined in the
///                    @a subgraph. It must be at least 3D. The first N-2 dimensions must match the first and second
///                    input tensors. The last 2 dimensions must be [M, N].
/// @param flags - binary features of the Batch Matrix Multiply Node. The only currently supported value is
///                XNN_FLAG_TRANSPOSE_B.
enum xnn_status xnn_define_batch_matrix_multiply(
  xnn_subgraph_t subgraph,
  uint32_t input1_id,
  uint32_t input2_id,
  uint32_t output_id,
  uint32_t flags);
/// Define a Ceiling Node and add it to a Subgraph.
///
/// The Ceiling Node rounds each element of the input tensor up to the nearest integer.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
///                    shape must match the shape of the input tensor.
/// @param flags - binary features of the Ceiling Node. No supported flags are currently defined.
enum xnn_status xnn_define_ceiling(
  xnn_subgraph_t subgraph,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags);
/// Define a Clamp Node and add it to a Subgraph.
///
/// The Clamp Node limits each element of the input tensor to the [output_min, output_max] range.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param output_min - lower bound for clipping output values.
/// @param output_max - upper bound for clipping output values.
/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
///                    shape must match the shape of the input tensor.
/// @param flags - binary features of the Clamp Node. No supported flags are currently defined.
enum xnn_status xnn_define_clamp(
  xnn_subgraph_t subgraph,
  float output_min,
  float output_max,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags);
/// Define an ELU (Exponential Linear Unit) Node and add it to a Subgraph.
///
/// The ELU Node computes, element-wise, f(x) = x for x >= 0 and f(x) = alpha * (exp(x) - 1) for x < 0.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param alpha - scale factor for negative output elements.
/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
///                    shape must match the shape of the input tensor.
/// @param flags - binary features of the ELU Node. No supported flags are currently defined.
enum xnn_status xnn_define_elu(
  xnn_subgraph_t subgraph,
  float alpha,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags);
/// Define a Floor Node and add it to a Subgraph.
///
/// The Floor Node rounds each element of the input tensor down to the nearest integer.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
///                    shape must match the shape of the input tensor.
/// @param flags - binary features of the Floor Node. No supported flags are currently defined.
enum xnn_status xnn_define_floor(
  xnn_subgraph_t subgraph,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags);
/// Define a HardSwish Node and add it to a Subgraph.
///
/// The HardSwish Node computes, element-wise, f(x) = x * max(0, min(6, x + 3)) / 6.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
///                    shape must match the shape of the input tensor.
/// @param flags - binary features of the HardSwish Node. No supported flags are currently defined.
enum xnn_status xnn_define_hardswish(
  xnn_subgraph_t subgraph,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags);
/// Define a Leaky ReLU Node and add it to a Subgraph.
///
/// The Leaky ReLU Node passes positive input elements unchanged and multiplies negative input elements by
/// @a negative_slope.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param negative_slope - scale factor for negative input elements.
/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
///                    shape must match the shape of the input tensor.
/// @param flags - binary features of the Leaky ReLU Node. No supported flags are currently defined.
enum xnn_status xnn_define_leaky_relu(
  xnn_subgraph_t subgraph,
  float negative_slope,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags);
/// Define a Negate Node and add it to a Subgraph.
///
/// The Negate Node computes the element-wise negation (-x) of the input tensor.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
///                    shape must match the shape of the input tensor.
/// @param flags - binary features of the Negate Node. No supported flags are currently defined.
enum xnn_status xnn_define_negate(
  xnn_subgraph_t subgraph,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags);
/// Define a Sigmoid Node and add it to a Subgraph.
///
/// The Sigmoid Node computes, element-wise, the logistic function f(x) = 1 / (1 + exp(-x)).
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
///                    shape must match the shape of the input tensor.
/// @param flags - binary features of the Sigmoid Node. No supported flags are currently defined.
enum xnn_status xnn_define_sigmoid(
  xnn_subgraph_t subgraph,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags);
/// Define a SoftMax Node and add it to a Subgraph.
///
/// NOTE(review): no parameter selects the normalization axis; presumably the SoftMax is computed along the
/// innermost dimension — confirm against the implementation.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph, and have at
///                   least one dimension.
/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
///                    shape must match the shape of the input tensor.
/// @param flags - binary features of the SoftMax Node. No supported flags are currently defined.
enum xnn_status xnn_define_softmax(
  xnn_subgraph_t subgraph,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags);
/// Define a Space To Depth 2D Node and add it to a Subgraph.
///
/// The Space To Depth 2D Node rearranges blocks of spatial data into channels (a reverse transform to Depth To
/// Space 2D). For a given output pixel, its channels are formed from the values of an input square of pixels with
/// side @a block_size. The output depth is therefore @a block_size x @a block_size times greater than that of the
/// input.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param block_size - the size of the spatial block.
/// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph
///                   with [N, IH * block_size, IW * block_size, OC] dimensions.
/// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph
///                    with [N, IH, IW, OC * block_size * block_size] dimensions.
/// @param flags - binary features of the Space To Depth 2D Node. No supported flags are currently defined.
enum xnn_status xnn_define_space_to_depth_2d(
  xnn_subgraph_t subgraph,
  uint32_t block_size,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags);
/// Define a Square Node and add it to a Subgraph.
///
/// The Square Node computes, element-wise, f(x) = x * x.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
///                    shape must match the shape of the input tensor.
/// @param flags - binary features of the Square Node. No supported flags are currently defined.
enum xnn_status xnn_define_square(
  xnn_subgraph_t subgraph,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags);
/// Define a Square Root Node and add it to a Subgraph.
///
/// The Square Root Node computes the element-wise square root of the input tensor.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
///                    shape must match the shape of the input tensor.
/// @param flags - binary features of the Square Root Node. No supported flags are currently defined.
enum xnn_status xnn_define_square_root(
  xnn_subgraph_t subgraph,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags);
/// Define a Reciprocal Square Root Node and add it to a Subgraph.
///
/// The Reciprocal Square Root Node computes, element-wise, f(x) = 1 / sqrt(x).
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param input_id - Value ID for the input tensor. The input tensor must be
///                   defined in the @a subgraph.
/// @param output_id - Value ID for the output tensor. The output tensor must be
///                    defined in the @a subgraph, and its
///                    shape must match the shape of the input tensor.
/// @param flags - binary features of the Reciprocal Square Root Node. No
///                supported flags are currently defined.
enum xnn_status xnn_define_reciprocal_square_root(xnn_subgraph_t subgraph,
                                                  uint32_t input_id,
                                                  uint32_t output_id,
                                                  uint32_t flags);
/// Define a Static Slice Node and add it to a Subgraph.
///
/// @param subgraph - a Subgraph object that will own the created Node.
/// @param num_dims - number of shape dimensions in the input and output tensor.
/// @param offsets - offsets in each dimension of the input tensor. This array must have @a num_dims elements.
/// @param sizes - size of each dimension in output tensor. This array must have @a num_dims elements.
/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
///                    dimensions must match @a sizes.
/// @param flags - binary features of the Static Slice Node. No supported flags are currently defined.
enum xnn_status xnn_define_static_slice(
  xnn_subgraph_t subgraph,
  size_t num_dims,
  const size_t* offsets,
  const size_t* sizes,
  uint32_t input_id,
  uint32_t output_id,
  uint32_t flags);
  1654. /// Define a Static Transpose Node and add it to a Subgraph.
  1655. ///
  1656. /// The Static Transpose Node applies a generalized transpose to the input tensor using the permuation in perm.
  1657. ///
  1658. /// @param subgraph - a Subgraph object that will own the created Node.
  1659. /// @param input_id - Value ID for the input tensor. The input tensor must be an N-dimensional tensor defined in
  1660. /// the @a subgraph.
  1661. /// @param output_id - Value ID for the output tensor. The output tensor must be an N-dimensional tensor defined
  1662. /// in the @a subgraph with each dimension equal to its corresponding permuted input dimension.
  1663. /// @param num_dims - the number of permutation dimensions. This must be equal to the number of input dimensions.
  1664. /// @param perm - The permutation of the axis of the input tensor. The perm array must must contain 0 to N-1 in the
  1665. /// permuted order.
  1666. /// @param flags - binary features of the Static Transpose Node. No supported flags are currently defined.
  1667. enum xnn_status xnn_define_static_transpose(
  1668. xnn_subgraph_t subgraph,
  1669. size_t num_dims,
  1670. const size_t* perm,
  1671. uint32_t input_id,
  1672. uint32_t output_id,
  1673. uint32_t flags);
  1674. /// Define a Tanh Node and add it to a Subgraph.
  1675. ///
  1676. /// @param subgraph - a Subgraph object that will own the created Node.
  1677. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
  1678. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
  1679. /// shape must match the shape of the input tensor.
  1680. /// @param flags - binary features of the Tanh Node. No supported flags are currently defined.
  1681. enum xnn_status xnn_define_tanh(
  1682. xnn_subgraph_t subgraph,
  1683. uint32_t input_id,
  1684. uint32_t output_id,
  1685. uint32_t flags);
  1686. /// Code cache is a cache for JIT generated code.
  1687. typedef struct xnn_code_cache* xnn_code_cache_t;
  1688. /// Weights cache can be finalized in these ways:
  1689. enum xnn_weights_cache_finalization_kind {
  1690. /// Weights cache is finalized, no insert operations into the weights cache is allowed, even if the "inserted"
  1691. /// weights already exist in thee cache. Weights cache memory will also be trimmed to page boundary and set to
  1692. /// read-only (to prevent writes).
  1693. xnn_weights_cache_finalization_kind_hard,
  1694. /// Weights cache will be finalized with some extra space at the end, this allows for "inserting" into the cache only
  1695. /// if the weights are already in the cache, and errors on inserting uncached weights. There is memory overhead.
  1696. xnn_weights_cache_finalization_kind_soft,
  1697. };
  1698. /// A combination of multiple factors to uniquely locate the weights cache.
  1699. struct xnn_weights_cache_look_up_key {
  1700. /// The unique seed for each ukernel. It is guaranteed that each ukernel provides
  1701. /// a consistent and identical seed.
  1702. uint32_t seed;
  1703. /// Pointer to the original kernel.
  1704. const void* kernel;
  1705. /// Pointer to the original bias, could be NULL.
  1706. const void* bias;
  1707. };
  1708. /// A group of function pointers to manage weights cache. All functions may be
  1709. /// called on multi threads.
  1710. struct xnn_weights_cache_provider {
  1711. /// User-specified pointer that will be passed as-is to all functions in this
  1712. /// structure.
  1713. void* context;
  1714. /// Looks up the tuple of {cache_key, kernel, bias} in the cache. If it is found,
  1715. /// returns the offset to the found entry for reuse. Otherwise, returns SIZE_MAX.
  1716. /// @param context - The user-specified pointer from xnn_weights_cache_provider structure.
  1717. /// @param cache_key - The key used to locate the weights cache entry.
  1718. size_t (*look_up)(void* context, const struct xnn_weights_cache_look_up_key* cache_key);
  1719. /// Ensures that cache has enough space for `n` bytes. Returns the address to
  1720. /// store weight cache. Returns NULL if fails to reserve space.
  1721. /// @param context - The user-specified pointer from xnn_weights_cache_provider structure.
  1722. /// @param n - size to be reserved.
  1723. void* (*reserve_space)(void* context, size_t n);
  1724. /// Looks up packed weights at `ptr` in the cache. If it is found, reuse it.
  1725. /// Otherwise, it is added to the cache. Returns the offset to the cache.
  1726. /// @param context - The user-specified pointer from xnn_weights_cache_provider structure.
  1727. /// @param cache_key - The key used to locate the weights cache entry.
  1728. /// @param ptr - pointer pointing to the packed weight.
  1729. /// @param size - size of the packed weight.
  1730. size_t (*look_up_or_insert)(void* context, const struct xnn_weights_cache_look_up_key* cache_key, void* ptr, size_t size);
  1731. /// Returns whether the cache is finalized.
  1732. /// @param context - The user-specified pointer from xnn_weights_cache_provider structure.
  1733. bool (*is_finalized)(void* context);
  1734. /// Returns the absolute pointer corresponding to `offset`, where the offset is returned from
  1735. /// `look_up` or `get_or_insert`. This function must be called after finalize.
  1736. /// @param context - The user-specified pointer from xnn_weights_cache_provider structure.
  1737. /// @param offset - offset to the start of internal buffer
  1738. void* (*offset_to_addr)(void* context, size_t offset);
  1739. /// Destroy a weights cache object, as well as memory used for the cache.
  1740. /// @param context - The user-specified pointer from xnn_weights_cache_provider structure.
  1741. enum xnn_status (*delete_cache)(void* context);
  1742. };
  1743. /// Weights cache is a cache for packed weights. It can be reused between runtimes.
  1744. typedef struct xnn_weights_cache_provider* xnn_weights_cache_t;
  1745. /// Create a weights cache object specifying the initial size of weights cache (in bytes).
  1746. ///
  1747. /// @param[in] size - initial capacity of the weights cache (in bytes), i.e. it can hold size bytes without growing.
  1748. /// @param weights_cache_out - pointer to the variable that will be initialized to a handle to the weights cache provider
  1749. /// upon successful return. Once created, the weights cache provider can be shared between
  1750. /// different Runtime objects.
  1751. enum xnn_status xnn_create_weights_cache_with_size(size_t size, xnn_weights_cache_t* weights_cache_out);
  1752. enum xnn_status xnn_create_weights_cache(xnn_weights_cache_t* weights_cache_out);
  1753. /// Finalizes the weights cache. The kind of finalization is specified by `finalization_kind`.
  1754. /// @param weights_cache - the weights cache object to finalize.
  1755. /// @param finalization_kind - the kind of finalization.
  1756. enum xnn_status xnn_finalize_weights_cache(
  1757. xnn_weights_cache_t weights_cache,
  1758. enum xnn_weights_cache_finalization_kind finalization_kind);
  1759. /// Destroy a weights cache object, as well as memory used for the cache.
  1760. /// @param weights_cache - the weights cache object to destroy.
  1761. enum xnn_status xnn_delete_weights_cache(xnn_weights_cache_t weights_cache);
  1762. typedef struct xnn_workspace* xnn_workspace_t;
  1763. /// Create a workspace object.
  1764. /// @param workspace_out - pointer to the variable that will be initialized to a handle to the workspace object upon
  1765. /// successful return. Once created, the workspace can be shared between different Runtime
  1766. /// objects.
  1767. enum xnn_status xnn_create_workspace(xnn_workspace_t* workspace_out);
  1768. /// Destroy a workspace object, as well as memory used by the workspace. Object destruction can be deferred until all
  1769. /// Runtime objects created with this workspace are destroyed.
  1770. /// @param workspace - the workspace object to destroy.
  1771. enum xnn_status xnn_release_workspace(xnn_workspace_t workspace);
  1772. /// Runtime is a combination of an execution plan for subgraph Nodes and a memory manager for subgraph Values.
  1773. typedef struct xnn_runtime* xnn_runtime_t;
  1774. enum xnn_profile_info {
  1775. /// Returns a size_t containing the number of operators.
  1776. xnn_profile_info_num_operators,
  1777. /// Returns a char[] containing the null character separated names of all operators.
  1778. xnn_profile_info_operator_name,
  1779. /// Returns a uint64_t[] with the runtimes of all operators in the same order as xnn_profile_info_operator_name.
  1780. xnn_profile_info_operator_timing,
  1781. };
  1782. /// Return profile information for all operators.
  1783. ///
  1784. /// @param runtime - a Runtime object created with @ref xnn_create_runtime, @ref xnn_create_runtime_v2 or
  1785. /// @ref xnn_create_runtime_v3.
  1786. /// @param param_name - type of profile information required.
  1787. /// @param param_value_size - the size in bytes of memory pointed to by param_value. If this is not sufficient then
  1788. /// param_value_size_ret will be set to the required size and xnn_status_out_of_memory will be
  1789. /// returned.
  1790. /// @param param_value - a pointer to memory location where appropriate values for a given param_value will be written.
  1791. /// @param param_value_size_ret - returns number of bytes required to write the result if param_value_size is not
  1792. /// sufficient.
  1793. enum xnn_status xnn_get_runtime_profiling_info(xnn_runtime_t runtime,
  1794. enum xnn_profile_info param_name,
  1795. size_t param_value_size,
  1796. void* param_value,
  1797. size_t* param_value_size_ret);
  1798. /// Create a Runtime object from a subgraph.
  1799. ///
  1800. /// @param subgraph - a Subgraph object with all Values and Nodes that would be handled by the runtime. No Values or
  1801. /// Nodes can be added to the runtime once it is constructed.
  1802. /// @param weights_cache - a cache for packed weights. The runtime will look up and reuse packed weights in this cache,
  1803. /// this will reduce memory allocated for packed weights.
  1804. /// @param workspace - a workspace to hold internal tensors. The runtime will allocate space used for internal tensors
  1805. /// and track them using workspace. Workspace can be shared and reused across different runtimes. If
  1806. /// workspace is NULL, there will be no sharing: each runtime has its own workspace.
  1807. /// @param threadpool - the thread pool to be used for parallelisation of computations in the runtime. If the thread
  1808. /// pool is NULL, the computation would run on the caller thread without parallelization.
  1809. /// @param flags - binary features of the runtime. The only currently supported values are
  1810. /// XNN_FLAG_HINT_SPARSE_INFERENCE, XNN_FLAG_HINT_FP16_INFERENCE, XNN_FLAG_FORCE_FP16_INFERENCE,
  1811. /// XNN_FLAG_YIELD_WORKERS, and XNN_FLAG_TRANSIENT_INDIRECTION_BUFFER. If XNN_FLAG_YIELD_WORKERS is
  1812. /// specified, worker threads would be yielded to the system scheduler after processing the last operator
  1813. /// in the Runtime. If XNN_FLAG_TRANSIENT_INDIRECTION_BUFFER is specified, convolution operators will
  1814. /// initialize indirection buffers on each inference run using temporary memory in the workspace, instead
  1815. /// of initializing persistent indirection buffers once.
  1816. /// @param runtime_out - pointer to the variable that will be initialized with a handle to the Runtime object upon
  1817. /// successful return. Once constructed, the Runtime object is independent of the Subgraph object
  1818. /// used to create it.
  1819. enum xnn_status xnn_create_runtime_v4(
  1820. xnn_subgraph_t subgraph,
  1821. xnn_weights_cache_t weights_cache,
  1822. xnn_workspace_t workspace,
  1823. pthreadpool_t threadpool,
  1824. uint32_t flags,
  1825. xnn_runtime_t* runtime_out);
  1826. enum xnn_status xnn_create_runtime_v3(
  1827. xnn_subgraph_t subgraph,
  1828. xnn_weights_cache_t weights_cache,
  1829. pthreadpool_t threadpool,
  1830. uint32_t flags,
  1831. xnn_runtime_t* runtime_out);
  1832. enum xnn_status xnn_create_runtime_v2(
  1833. xnn_subgraph_t subgraph,
  1834. pthreadpool_t threadpool,
  1835. uint32_t flags,
  1836. xnn_runtime_t* runtime_out);
  1837. enum xnn_status xnn_create_runtime(
  1838. xnn_subgraph_t subgraph,
  1839. xnn_runtime_t* runtime_out);
  1840. struct xnn_external_value {
  1841. uint32_t id;
  1842. void* data;
  1843. };
  1844. /// Reshape an external value.
  1845. ///
  1846. /// @param external_id - external ID for the Value. The ID must be within the range of reversed Value IDs specified on
  1847. /// the Subgraph creation. If the external ID is XNN_INVALID_VALUE_ID, an internal ID will be
  1848. /// created for the Value.
  1849. /// @param num_dims - number of dimensions in the shape.
  1850. /// @param dims - pointer to an array of @a num_dims shape dimensions. If num_dims is 0, this pointer can be NULL.
  1851. /// XNNPACK does not keep any pointers to this array after the function returns.
  1852. enum xnn_status xnn_reshape_external_value(
  1853. xnn_runtime_t runtime,
  1854. uint32_t external_id,
  1855. size_t num_dims,
  1856. const size_t* dims);
  1857. /// Get the external value shape.
  1858. ///
  1859. /// @param external_id - external ID for the Value. The ID must be within the range of reversed Value IDs specified on
  1860. /// the Subgraph creation. The external ID can not be XNN_INVALID_VALUE_ID.
  1861. /// @param num_dims - A valid pointer into which the number of dimensions in the shape will be written. It can not be larger than XNN_MAX_TENSOR_DIMS.
  1862. /// @param dims - pointer to an array of @a num_dims shape dimensions. This pointer can't be NULL. It must be large enough to hold
  1863. /// at least @a num_dims elements. XNNPACK does not keep any pointers to this array after the function returns.
  1864. enum xnn_status xnn_get_external_value_shape(
  1865. xnn_runtime_t runtime,
  1866. uint32_t external_id,
  1867. size_t* num_dims,
  1868. size_t* dims);
  1869. /// Reshape the XNNPACK runtime.
  1870. ///
  1871. /// Propgates the shapes of input tensors through the graph to determine the shapes of intermediate and output tensors.
  1872. /// Memory is allocated if required. Output tensor shapes are returned by xnn_get_external_value_shape.
  1873. ///
  1874. /// @param runtime - a Runtime object created with @ref xnn_create_runtime or @ref xnn_create_runtime_v2.
  1875. enum xnn_status xnn_reshape_runtime(
  1876. xnn_runtime_t runtime);
  1877. /// Deprecated. Use xnn_reshape_runtime and xnn_setup_runtime_v2.
  1878. ///
  1879. /// Setup data pointers for external inputs and outputs in a Runtime object and
  1880. /// allocate memory.
  1881. ///
  1882. /// @param runtime - a Runtime object created with @ref xnn_create_runtime or @ref xnn_create_runtime_v2.
  1883. /// @param num_external_values - the number of external inputs and outputs specified in this call. This number must
  1884. /// match the number of external inputs and outputs in the runtime, i.e. all external
  1885. /// inputs and outputs in the runtime must be specified in one call.
  1886. /// @param external_values - array with location information for all external inputs and outputs in the runtime.
  1887. enum xnn_status xnn_setup_runtime(
  1888. xnn_runtime_t runtime,
  1889. size_t num_external_values,
  1890. const struct xnn_external_value* external_values);
  1891. /// Setup data pointers for external inputs and outputs in a Runtime object.
  1892. /// Should be called after xnn_reshape_runtime.
  1893. ///
  1894. /// @param runtime - a Runtime object created with @ref xnn_create_runtime or @ref xnn_create_runtime_v2.
  1895. /// @param num_external_values - the number of external inputs and outputs specified in this call. This number must
  1896. /// match the number of external inputs and outputs in the runtime, i.e. all external
  1897. /// inputs and outputs in the runtime must be specified in one call.
  1898. /// @param external_values - array with location information for all external inputs and outputs in the runtime.
  1899. enum xnn_status xnn_setup_runtime_v2(
  1900. xnn_runtime_t runtime,
  1901. size_t num_external_values,
  1902. const struct xnn_external_value* external_values);
  1903. /// Execute forward pass for all operators in the runtime.
  1904. ///
  1905. /// @param runtime - the Runtime object with the execution plan to invoke.
  1906. enum xnn_status xnn_invoke_runtime(
  1907. xnn_runtime_t runtime);
  1908. /// Destroy a Runtime object, as well as operators and memory associated with it.
  1909. ///
  1910. /// @param runtime - the Runtime object to destroy.
  1911. enum xnn_status xnn_delete_runtime(
  1912. xnn_runtime_t runtime);
  1913. typedef struct xnn_operator* xnn_operator_t;
  1914. enum xnn_status xnn_run_operator(
  1915. xnn_operator_t op,
  1916. pthreadpool_t threadpool);
  1917. enum xnn_status xnn_delete_operator(
  1918. xnn_operator_t op);
  1919. /// Operator API:
  1920. /// - create operator will create and populate a xnn_operator_t
  1921. /// - reshape operator will update fields in xnn_operator_t with shape/dimensions and parallelization information
  1922. /// - setup operator will update pointers to input and outputs
  1923. /// Each supported operator must have a create, reshape, and setup function. (Optionally a run function.)
  1924. /// Operators listed below are in alphabetical order by operator name; within each operator, we sort alphabetically by
  1925. /// data layout and type. We also group create, reshape, setup (and optionally run) functions of each operator together.
  1926. enum xnn_status xnn_create_abs_nc_f16(
  1927. uint32_t flags,
  1928. xnn_operator_t* abs_op_out);
  1929. enum xnn_status xnn_reshape_abs_nc_f16(
  1930. xnn_operator_t abs_op,
  1931. size_t batch_size,
  1932. size_t channels,
  1933. size_t input_stride,
  1934. size_t output_stride,
  1935. pthreadpool_t threadpool);
  1936. enum xnn_status xnn_setup_abs_nc_f16(
  1937. xnn_operator_t abs_op,
  1938. const void* input,
  1939. void* output);
  1940. enum xnn_status xnn_create_abs_nc_f32(
  1941. uint32_t flags,
  1942. xnn_operator_t* abs_op_out);
  1943. enum xnn_status xnn_reshape_abs_nc_f32(
  1944. xnn_operator_t abs_op,
  1945. size_t batch_size,
  1946. size_t channels,
  1947. size_t input_stride,
  1948. size_t output_stride,
  1949. pthreadpool_t threadpool);
  1950. enum xnn_status xnn_setup_abs_nc_f32(
  1951. xnn_operator_t abs_op,
  1952. const float* input,
  1953. float* output);
  1954. enum xnn_status xnn_run_abs_nc_f32(
  1955. size_t channels,
  1956. size_t input_stride,
  1957. size_t output_stride,
  1958. size_t batch_size,
  1959. const float* input,
  1960. float* output,
  1961. uint32_t flags,
  1962. pthreadpool_t threadpool);
  1963. enum xnn_status xnn_create_add_nd_f16(
  1964. float output_min,
  1965. float output_max,
  1966. uint32_t flags,
  1967. xnn_operator_t* add_op_out);
  1968. enum xnn_status xnn_reshape_add_nd_f16(
  1969. xnn_operator_t add_op,
  1970. size_t num_input1_dims,
  1971. const size_t* input1_shape,
  1972. size_t num_input2_dims,
  1973. const size_t* input2_shape,
  1974. pthreadpool_t threadpool);
  1975. enum xnn_status xnn_setup_add_nd_f16(
  1976. xnn_operator_t add_op,
  1977. const void* input1,
  1978. const void* input2,
  1979. void* output);
  1980. enum xnn_status xnn_create_add_nd_f32(
  1981. float output_min,
  1982. float output_max,
  1983. uint32_t flags,
  1984. xnn_operator_t* add_op_out);
  1985. enum xnn_status xnn_reshape_add_nd_f32(
  1986. xnn_operator_t add_op,
  1987. size_t num_input1_dims,
  1988. const size_t* input1_shape,
  1989. size_t num_input2_dims,
  1990. const size_t* input2_shape,
  1991. pthreadpool_t threadpool);
  1992. enum xnn_status xnn_setup_add_nd_f32(
  1993. xnn_operator_t add_op,
  1994. const float* input1,
  1995. const float* input2,
  1996. float* output);
  1997. enum xnn_status xnn_run_add_nd_f32(
  1998. size_t num_input1_dims,
  1999. const size_t* input1_shape,
  2000. size_t num_input2_dims,
  2001. const size_t* input2_shape,
  2002. const float* input1,
  2003. const float* input2,
  2004. float* output,
  2005. float output_min,
  2006. float output_max,
  2007. uint32_t flags,
  2008. pthreadpool_t threadpool);
  2009. enum xnn_status xnn_create_add_nd_qs8(
  2010. int8_t input1_zero_point,
  2011. float input1_scale,
  2012. int8_t input2_zero_point,
  2013. float input2_scale,
  2014. int8_t output_zero_point,
  2015. float output_scale,
  2016. int8_t output_min,
  2017. int8_t output_max,
  2018. uint32_t flags,
  2019. xnn_operator_t* add_op_out);
  2020. enum xnn_status xnn_reshape_add_nd_qs8(
  2021. xnn_operator_t add_op,
  2022. size_t num_input1_dims,
  2023. const size_t* input1_shape,
  2024. size_t num_input2_dims,
  2025. const size_t* input2_shape,
  2026. pthreadpool_t threadpool);
  2027. enum xnn_status xnn_setup_add_nd_qs8(
  2028. xnn_operator_t add_op,
  2029. const int8_t* input1,
  2030. const int8_t* input2,
  2031. int8_t* output);
  2032. enum xnn_status xnn_run_add_nd_qs8(
  2033. size_t num_input1_dims,
  2034. const size_t* input1_shape,
  2035. int8_t input1_zero_point,
  2036. float input1_scale,
  2037. size_t num_input2_dims,
  2038. const size_t* input2_shape,
  2039. int8_t input2_zero_point,
  2040. float input2_scale,
  2041. const int8_t* input1,
  2042. const int8_t* input2,
  2043. int8_t* output,
  2044. int8_t output_zero_point,
  2045. float output_scale,
  2046. int8_t output_min,
  2047. int8_t output_max,
  2048. uint32_t flags,
  2049. pthreadpool_t threadpool);
  2050. enum xnn_status xnn_create_add_nd_qu8(
  2051. uint8_t input1_zero_point,
  2052. float input1_scale,
  2053. uint8_t input2_zero_point,
  2054. float input2_scale,
  2055. uint8_t output_zero_point,
  2056. float output_scale,
  2057. uint8_t output_min,
  2058. uint8_t output_max,
  2059. uint32_t flags,
  2060. xnn_operator_t* add_op_out);
  2061. enum xnn_status xnn_reshape_add_nd_qu8(
  2062. xnn_operator_t add_op,
  2063. size_t num_input1_dims,
  2064. const size_t* input1_shape,
  2065. size_t num_input2_dims,
  2066. const size_t* input2_shape,
  2067. pthreadpool_t threadpool);
  2068. enum xnn_status xnn_setup_add_nd_qu8(
  2069. xnn_operator_t add_op,
  2070. const uint8_t* input1,
  2071. const uint8_t* input2,
  2072. uint8_t* output);
  2073. enum xnn_status xnn_run_add_nd_qu8(
  2074. size_t num_input1_dims,
  2075. const size_t* input1_shape,
  2076. uint8_t input1_zero_point,
  2077. float input1_scale,
  2078. size_t num_input2_dims,
  2079. const size_t* input2_shape,
  2080. uint8_t input2_zero_point,
  2081. float input2_scale,
  2082. const uint8_t* input1,
  2083. const uint8_t* input2,
  2084. uint8_t* output,
  2085. uint8_t output_zero_point,
  2086. float output_scale,
  2087. uint8_t output_min,
  2088. uint8_t output_max,
  2089. uint32_t flags,
  2090. pthreadpool_t threadpool);
  2091. enum xnn_status xnn_create_argmax_pooling2d_nhwc_f32(
  2092. uint32_t input_padding_top,
  2093. uint32_t input_padding_right,
  2094. uint32_t input_padding_bottom,
  2095. uint32_t input_padding_left,
  2096. uint32_t pooling_height,
  2097. uint32_t pooling_width,
  2098. uint32_t flags,
  2099. xnn_operator_t* argmax_pooling_op_out);
  2100. enum xnn_status xnn_reshape_argmax_pooling2d_nhwc_f32(
  2101. xnn_operator_t argmax_pooling_op,
  2102. size_t batch_size,
  2103. size_t input_height,
  2104. size_t input_width,
  2105. size_t channels,
  2106. size_t input_pixel_stride,
  2107. size_t output_pixel_stride,
  2108. size_t* workspace_size,
  2109. size_t* workspace_alignment,
  2110. size_t* output_height_out,
  2111. size_t* output_width_out,
  2112. pthreadpool_t threadpool);
  2113. enum xnn_status xnn_setup_argmax_pooling2d_nhwc_f32(
  2114. xnn_operator_t argmax_pooling_op,
  2115. void* workspace,
  2116. const float* input,
  2117. float* output,
  2118. uint32_t* index);
  2119. enum xnn_status xnn_create_average_pooling2d_nhwc_f16(
  2120. uint32_t input_padding_top,
  2121. uint32_t input_padding_right,
  2122. uint32_t input_padding_bottom,
  2123. uint32_t input_padding_left,
  2124. uint32_t pooling_height,
  2125. uint32_t pooling_width,
  2126. uint32_t stride_height,
  2127. uint32_t stride_width,
  2128. float output_min,
  2129. float output_max,
  2130. uint32_t flags,
  2131. xnn_operator_t* average_pooling_op_out);
  2132. enum xnn_status xnn_reshape_average_pooling2d_nhwc_f16(
  2133. xnn_operator_t average_pooling_op,
  2134. size_t batch_size,
  2135. size_t input_height,
  2136. size_t input_width,
  2137. size_t channels,
  2138. size_t input_pixel_stride,
  2139. size_t output_pixel_stride,
  2140. size_t* workspace_size,
  2141. size_t* workspace_alignment,
  2142. size_t* output_height_out,
  2143. size_t* output_width_out,
  2144. pthreadpool_t threadpool);
  2145. enum xnn_status xnn_setup_average_pooling2d_nhwc_f16(
  2146. xnn_operator_t average_pooling_op,
  2147. void* workspace,
  2148. const void* input,
  2149. void* output);
  2150. enum xnn_status xnn_create_average_pooling2d_nhwc_f32(
  2151. uint32_t input_padding_top,
  2152. uint32_t input_padding_right,
  2153. uint32_t input_padding_bottom,
  2154. uint32_t input_padding_left,
  2155. uint32_t pooling_height,
  2156. uint32_t pooling_width,
  2157. uint32_t stride_height,
  2158. uint32_t stride_width,
  2159. float output_min,
  2160. float output_max,
  2161. uint32_t flags,
  2162. xnn_operator_t* average_pooling_op_out);
  2163. enum xnn_status xnn_reshape_average_pooling2d_nhwc_f32(
  2164. xnn_operator_t average_pooling_op,
  2165. size_t batch_size,
  2166. size_t input_height,
  2167. size_t input_width,
  2168. size_t channels,
  2169. size_t input_pixel_stride,
  2170. size_t output_pixel_stride,
  2171. size_t* workspace_size,
  2172. size_t* workspace_alignment,
  2173. size_t* output_height_out,
  2174. size_t* output_width_out,
  2175. pthreadpool_t threadpool);
  2176. enum xnn_status xnn_setup_average_pooling2d_nhwc_f32(
  2177. xnn_operator_t average_pooling_op,
  2178. void* workspace,
  2179. const float* input,
  2180. float* output);
  2181. enum xnn_status xnn_create_average_pooling2d_nhwc_qu8(
  2182. uint32_t input_padding_top,
  2183. uint32_t input_padding_right,
  2184. uint32_t input_padding_bottom,
  2185. uint32_t input_padding_left,
  2186. uint32_t pooling_height,
  2187. uint32_t pooling_width,
  2188. uint32_t stride_height,
  2189. uint32_t stride_width,
  2190. uint8_t input_zero_point,
  2191. float input_scale,
  2192. uint8_t output_zero_point,
  2193. float output_scale,
  2194. uint8_t output_min,
  2195. uint8_t output_max,
  2196. uint32_t flags,
  2197. xnn_operator_t* average_pooling_op_out);
  2198. enum xnn_status xnn_reshape_average_pooling2d_nhwc_qu8(
  2199. xnn_operator_t average_pooling_op,
  2200. size_t batch_size,
  2201. size_t input_height,
  2202. size_t input_width,
  2203. size_t channels,
  2204. size_t input_pixel_stride,
  2205. size_t output_pixel_stride,
  2206. size_t* workspace_size,
  2207. size_t* workspace_alignment,
  2208. size_t* output_height_out,
  2209. size_t* output_width_out,
  2210. pthreadpool_t threadpool);
  2211. enum xnn_status xnn_setup_average_pooling2d_nhwc_qu8(
  2212. xnn_operator_t average_pooling_op,
  2213. void* workspace,
  2214. const uint8_t* input,
  2215. uint8_t* output);
  2216. enum xnn_status xnn_create_bankers_rounding_nc_f16(
  2217. uint32_t flags,
  2218. xnn_operator_t* rounding_op_out);
  2219. enum xnn_status xnn_reshape_bankers_rounding_nc_f16(
  2220. xnn_operator_t rounding_op,
  2221. size_t batch_size,
  2222. size_t channels,
  2223. size_t input_stride,
  2224. size_t output_stride,
  2225. pthreadpool_t threadpool);
  2226. enum xnn_status xnn_setup_bankers_rounding_nc_f16(
  2227. xnn_operator_t rounding_op,
  2228. const void* input,
  2229. void* output);
  2230. enum xnn_status xnn_create_bankers_rounding_nc_f32(
  2231. uint32_t flags,
  2232. xnn_operator_t* rounding_op_out);
  2233. enum xnn_status xnn_reshape_bankers_rounding_nc_f32(
  2234. xnn_operator_t rounding_op,
  2235. size_t batch_size,
  2236. size_t channels,
  2237. size_t input_stride,
  2238. size_t output_stride,
  2239. pthreadpool_t threadpool);
  2240. enum xnn_status xnn_setup_bankers_rounding_nc_f32(
  2241. xnn_operator_t rounding_op,
  2242. const float* input,
  2243. float* output);
  2244. enum xnn_status xnn_run_bankers_rounding_nc_f32(
  2245. size_t channels,
  2246. size_t input_stride,
  2247. size_t output_stride,
  2248. size_t batch_size,
  2249. const float* input,
  2250. float* output,
  2251. uint32_t flags,
  2252. pthreadpool_t threadpool);
  2253. enum xnn_status xnn_create_batch_matrix_multiply_nc_f16(
  2254. uint32_t flags,
  2255. xnn_operator_t* batch_matrix_multiply_op);
  2256. enum xnn_status xnn_reshape_batch_matrix_multiply_nc_f16(
  2257. xnn_operator_t batch_matrix_multiply_op,
  2258. size_t batch_size,
  2259. size_t m,
  2260. size_t k,
  2261. size_t n,
  2262. size_t* workspace_size,
  2263. size_t* workspace_alignment,
  2264. pthreadpool_t threadpool);
  2265. enum xnn_status xnn_setup_batch_matrix_multiply_nc_f16(
  2266. xnn_operator_t batch_matrix_multiply_op,
  2267. void* workspace,
  2268. const void* lhs_input,
  2269. const void* rhs_input,
  2270. void* output);
  2271. enum xnn_status xnn_create_batch_matrix_multiply_nc_f32(
  2272. uint32_t flags,
  2273. xnn_operator_t* batch_matrix_multiply_op);
  2274. enum xnn_status xnn_reshape_batch_matrix_multiply_nc_f32(
  2275. xnn_operator_t batch_matrix_multiply_op,
  2276. size_t batch_size,
  2277. size_t m,
  2278. size_t k,
  2279. size_t n,
  2280. size_t* workspace_size,
  2281. size_t* workspace_alignment,
  2282. pthreadpool_t threadpool);
  2283. enum xnn_status xnn_setup_batch_matrix_multiply_nc_f32(
  2284. xnn_operator_t batch_matrix_multiply_op,
  2285. void* workspace,
  2286. const float* lhs_input,
  2287. const float* rhs_input,
  2288. float* output);
  2289. enum xnn_status xnn_create_ceiling_nc_f16(
  2290. uint32_t flags,
  2291. xnn_operator_t* ceiling_op_out);
  2292. enum xnn_status xnn_reshape_ceiling_nc_f16(
  2293. xnn_operator_t ceiling_op,
  2294. size_t batch_size,
  2295. size_t channels,
  2296. size_t input_stride,
  2297. size_t output_stride,
  2298. pthreadpool_t threadpool);
  2299. enum xnn_status xnn_setup_ceiling_nc_f16(
  2300. xnn_operator_t ceiling_op,
  2301. const void* input,
  2302. void* output);
  2303. enum xnn_status xnn_create_ceiling_nc_f32(
  2304. uint32_t flags,
  2305. xnn_operator_t* ceiling_op_out);
  2306. enum xnn_status xnn_run_ceiling_nc_f32(
  2307. size_t channels,
  2308. size_t input_stride,
  2309. size_t output_stride,
  2310. size_t batch_size,
  2311. const float* input,
  2312. float* output,
  2313. uint32_t flags,
  2314. pthreadpool_t threadpool);
  2315. enum xnn_status xnn_reshape_ceiling_nc_f32(
  2316. xnn_operator_t ceiling_op,
  2317. size_t batch_size,
  2318. size_t channels,
  2319. size_t input_stride,
  2320. size_t output_stride,
  2321. pthreadpool_t threadpool);
  2322. enum xnn_status xnn_setup_ceiling_nc_f32(
  2323. xnn_operator_t ceiling_op,
  2324. const float* input,
  2325. float* output);
  2326. enum xnn_status xnn_create_channel_shuffle_nc_x8(
  2327. size_t groups,
  2328. size_t group_channels,
  2329. size_t input_stride,
  2330. size_t output_stride,
  2331. uint32_t flags,
  2332. xnn_operator_t* channel_shuffle_op_out);
  2333. enum xnn_status xnn_reshape_channel_shuffle_nc_x8(
  2334. xnn_operator_t channel_shuffle_op,
  2335. size_t batch_size,
  2336. pthreadpool_t threadpool);
  2337. enum xnn_status xnn_setup_channel_shuffle_nc_x8(
  2338. xnn_operator_t channel_shuffle_op,
  2339. const void* input,
  2340. void* output);
  2341. enum xnn_status xnn_create_channel_shuffle_nc_x32(
  2342. size_t groups,
  2343. size_t group_channels,
  2344. size_t input_stride,
  2345. size_t output_stride,
  2346. uint32_t flags,
  2347. xnn_operator_t* channel_shuffle_op_out);
  2348. enum xnn_status xnn_reshape_channel_shuffle_nc_x32(
  2349. xnn_operator_t channel_shuffle_op,
  2350. size_t batch_size,
  2351. pthreadpool_t threadpool);
  2352. enum xnn_status xnn_setup_channel_shuffle_nc_x32(
  2353. xnn_operator_t channel_shuffle_op,
  2354. const void* input,
  2355. void* output);
  2356. enum xnn_status xnn_create_clamp_nc_f16(
  2357. float output_min,
  2358. float output_max,
  2359. uint32_t flags,
  2360. xnn_operator_t* clamp_op_out);
  2361. enum xnn_status xnn_reshape_clamp_nc_f16(
  2362. xnn_operator_t clamp_op,
  2363. size_t batch_size,
  2364. size_t channels,
  2365. size_t input_stride,
  2366. size_t output_stride,
  2367. pthreadpool_t threadpool);
  2368. enum xnn_status xnn_setup_clamp_nc_f16(
  2369. xnn_operator_t clamp_op,
  2370. const void* input,
  2371. void* output);
  2372. enum xnn_status xnn_create_clamp_nc_f32(
  2373. float output_min,
  2374. float output_max,
  2375. uint32_t flags,
  2376. xnn_operator_t* clamp_op_out);
  2377. enum xnn_status xnn_reshape_clamp_nc_f32(
  2378. xnn_operator_t clamp_op,
  2379. size_t batch_size,
  2380. size_t channels,
  2381. size_t input_stride,
  2382. size_t output_stride,
  2383. pthreadpool_t threadpool);
  2384. enum xnn_status xnn_setup_clamp_nc_f32(
  2385. xnn_operator_t clamp_op,
  2386. const float* input,
  2387. float* output);
  2388. enum xnn_status xnn_run_clamp_nc_f32(
  2389. size_t channels,
  2390. size_t input_stride,
  2391. size_t output_stride,
  2392. size_t batch_size,
  2393. const float* input,
  2394. float* output,
  2395. float output_min,
  2396. float output_max,
  2397. uint32_t flags,
  2398. pthreadpool_t threadpool);
  2399. enum xnn_status xnn_create_clamp_nc_s8(
  2400. int8_t output_min,
  2401. int8_t output_max,
  2402. uint32_t flags,
  2403. xnn_operator_t* clamp_op_out);
  2404. enum xnn_status xnn_reshape_clamp_nc_s8(
  2405. xnn_operator_t clamp_op,
  2406. size_t batch_size,
  2407. size_t channels,
  2408. size_t input_stride,
  2409. size_t output_stride,
  2410. pthreadpool_t threadpool);
  2411. enum xnn_status xnn_setup_clamp_nc_s8(
  2412. xnn_operator_t clamp_op,
  2413. const int8_t* input,
  2414. int8_t* output);
  2415. enum xnn_status xnn_create_clamp_nc_u8(
  2416. uint8_t output_min,
  2417. uint8_t output_max,
  2418. uint32_t flags,
  2419. xnn_operator_t* clamp_op_out);
  2420. enum xnn_status xnn_reshape_clamp_nc_u8(
  2421. xnn_operator_t clamp_op,
  2422. size_t batch_size,
  2423. size_t channels,
  2424. size_t input_stride,
  2425. size_t output_stride,
  2426. pthreadpool_t threadpool);
  2427. enum xnn_status xnn_setup_clamp_nc_u8(
  2428. xnn_operator_t clamp_op,
  2429. const uint8_t* input,
  2430. uint8_t* output);
  2431. enum xnn_status xnn_create_constant_pad_nd_x8(
  2432. const void* padding_value,
  2433. uint32_t flags,
  2434. xnn_operator_t* constant_pad_op_out);
  2435. enum xnn_status xnn_reshape_constant_pad_nd_x8(
  2436. xnn_operator_t constant_pad_op,
  2437. size_t num_dims,
  2438. const size_t* input_shape,
  2439. const size_t* pre_padding,
  2440. const size_t* post_padding,
  2441. pthreadpool_t threadpool);
  2442. enum xnn_status xnn_setup_constant_pad_nd_x8(
  2443. xnn_operator_t constant_pad_op,
  2444. const void* input,
  2445. void* output);
  2446. enum xnn_status xnn_run_constant_pad_nd_x8(
  2447. uint32_t flags,
  2448. size_t num_dims,
  2449. const size_t* input_shape,
  2450. const size_t* pre_paddings,
  2451. const size_t* post_paddings,
  2452. const void* input,
  2453. void* output,
  2454. const void* padding_value,
  2455. pthreadpool_t threadpool);
  2456. enum xnn_status xnn_create_constant_pad_nd_x16(
  2457. const void* padding_value,
  2458. uint32_t flags,
  2459. xnn_operator_t* constant_pad_op_out);
  2460. enum xnn_status xnn_reshape_constant_pad_nd_x16(
  2461. xnn_operator_t constant_pad_op,
  2462. size_t num_dims,
  2463. const size_t* input_shape,
  2464. const size_t* pre_padding,
  2465. const size_t* post_padding,
  2466. pthreadpool_t threadpool);
  2467. enum xnn_status xnn_setup_constant_pad_nd_x16(
  2468. xnn_operator_t constant_pad_op,
  2469. const void* input,
  2470. void* output);
  2471. enum xnn_status xnn_run_constant_pad_nd_x16(
  2472. uint32_t flags,
  2473. size_t num_dims,
  2474. const size_t* input_shape,
  2475. const size_t* pre_paddings,
  2476. const size_t* post_paddings,
  2477. const void* input,
  2478. void* output,
  2479. const void* padding_value,
  2480. pthreadpool_t threadpool);
  2481. enum xnn_status xnn_create_constant_pad_nd_x32(
  2482. const void* padding_value,
  2483. uint32_t flags,
  2484. xnn_operator_t* constant_pad_op_out);
  2485. enum xnn_status xnn_reshape_constant_pad_nd_x32(
  2486. xnn_operator_t constant_pad_op,
  2487. size_t num_dims,
  2488. const size_t* input_shape,
  2489. const size_t* pre_padding,
  2490. const size_t* post_padding,
  2491. pthreadpool_t threadpool);
  2492. enum xnn_status xnn_setup_constant_pad_nd_x32(
  2493. xnn_operator_t constant_pad_op,
  2494. const void* input,
  2495. void* output);
  2496. enum xnn_status xnn_run_constant_pad_nd_x32(
  2497. uint32_t flags,
  2498. size_t num_dims,
  2499. const size_t* input_shape,
  2500. const size_t* pre_paddings,
  2501. const size_t* post_paddings,
  2502. const void* input,
  2503. void* output,
  2504. const void* padding_value,
  2505. pthreadpool_t threadpool);
  2506. enum xnn_status xnn_create_convert_nc_f16_f32(
  2507. uint32_t flags,
  2508. xnn_operator_t* convert_op_out);
  2509. enum xnn_status xnn_reshape_convert_nc_f16_f32(
  2510. xnn_operator_t convert_op,
  2511. size_t batch_size,
  2512. size_t channels,
  2513. size_t input_stride,
  2514. size_t output_stride,
  2515. pthreadpool_t threadpool);
  2516. enum xnn_status xnn_setup_convert_nc_f16_f32(
  2517. xnn_operator_t convert_op,
  2518. const void* input,
  2519. float* output);
  2520. enum xnn_status xnn_run_convert_nc_f16_f32(
  2521. size_t channels,
  2522. size_t input_stride,
  2523. size_t output_stride,
  2524. size_t batch_size,
  2525. const void* input,
  2526. float* output,
  2527. uint32_t flags,
  2528. pthreadpool_t threadpool);
  2529. enum xnn_status xnn_create_convert_nc_f16_qd8(
  2530. uint32_t flags,
  2531. xnn_operator_t* convert_op_out);
  2532. enum xnn_status xnn_reshape_convert_nc_f16_qd8(
  2533. xnn_operator_t convert_op,
  2534. size_t batch_size,
  2535. size_t channels,
  2536. size_t input_stride,
  2537. size_t output_stride,
  2538. pthreadpool_t threadpool);
  2539. // quantization_params must be padded with at least XNN_EXTRA_QUANTIZATION_PARAMS entries.
  2540. enum xnn_status xnn_setup_convert_nc_f16_qd8(
  2541. xnn_operator_t convert_op,
  2542. const void* input,
  2543. int8_t* output,
  2544. struct xnn_dynamic_quantization_params* quantization_params);
  2545. enum xnn_status xnn_create_convert_nc_f32_qd8(
  2546. uint32_t flags,
  2547. xnn_operator_t* convert_op_out);
  2548. enum xnn_status xnn_reshape_convert_nc_f32_qd8(
  2549. xnn_operator_t convert_op,
  2550. size_t batch_size,
  2551. size_t channels,
  2552. size_t input_stride,
  2553. size_t output_stride,
  2554. pthreadpool_t threadpool);
  2555. // quantization_params must be padded with at least XNN_EXTRA_QUANTIZATION_PARAMS entries.
  2556. enum xnn_status xnn_setup_convert_nc_f32_qd8(
  2557. xnn_operator_t convert_op,
  2558. const float* input,
  2559. int8_t* output,
  2560. struct xnn_dynamic_quantization_params* quantization_params);
  2561. enum xnn_status xnn_create_convert_nc_f32_f16(
  2562. uint32_t flags,
  2563. xnn_operator_t* convert_op_out);
  2564. enum xnn_status xnn_reshape_convert_nc_f32_f16(
  2565. xnn_operator_t convert_op,
  2566. size_t batch_size,
  2567. size_t channels,
  2568. size_t input_stride,
  2569. size_t output_stride,
  2570. pthreadpool_t threadpool);
  2571. enum xnn_status xnn_setup_convert_nc_f32_f16(
  2572. xnn_operator_t convert_op,
  2573. const float* input,
  2574. void* output);
  2575. enum xnn_status xnn_run_convert_nc_f32_f16(
  2576. size_t channels,
  2577. size_t input_stride,
  2578. size_t output_stride,
  2579. size_t batch_size,
  2580. const float* input,
  2581. void* output,
  2582. uint32_t flags,
  2583. pthreadpool_t threadpool);
  2584. enum xnn_status xnn_create_convert_nc_f32_qs8(
  2585. float output_scale,
  2586. int8_t output_zero_point,
  2587. int8_t output_min,
  2588. int8_t output_max,
  2589. uint32_t flags,
  2590. xnn_operator_t* convert_op_out);
  2591. enum xnn_status xnn_reshape_convert_nc_f32_qs8(
  2592. xnn_operator_t convert_op,
  2593. size_t batch_size,
  2594. size_t channels,
  2595. size_t input_stride,
  2596. size_t output_stride,
  2597. pthreadpool_t threadpool);
  2598. enum xnn_status xnn_setup_convert_nc_f32_qs8(
  2599. xnn_operator_t convert_op,
  2600. const float* input,
  2601. int8_t* output);
  2602. enum xnn_status xnn_run_convert_nc_f32_qs8(
  2603. size_t channels,
  2604. size_t input_stride,
  2605. size_t output_stride,
  2606. size_t batch_size,
  2607. const float* input,
  2608. int8_t* output,
  2609. float output_scale,
  2610. int8_t output_zero_point,
  2611. uint32_t flags,
  2612. pthreadpool_t threadpool);
  2613. enum xnn_status xnn_create_convert_nc_f32_qu8(
  2614. float output_scale,
  2615. uint8_t output_zero_point,
  2616. uint8_t output_min,
  2617. uint8_t output_max,
  2618. uint32_t flags,
  2619. xnn_operator_t* convert_op_out);
  2620. enum xnn_status xnn_reshape_convert_nc_f32_qu8(
  2621. xnn_operator_t convert_op,
  2622. size_t batch_size,
  2623. size_t channels,
  2624. size_t input_stride,
  2625. size_t output_stride,
  2626. pthreadpool_t threadpool);
  2627. enum xnn_status xnn_setup_convert_nc_f32_qu8(
  2628. xnn_operator_t convert_op,
  2629. const float* input,
  2630. uint8_t* output);
  2631. enum xnn_status xnn_run_convert_nc_f32_qu8(
  2632. size_t channels,
  2633. size_t input_stride,
  2634. size_t output_stride,
  2635. size_t batch_size,
  2636. const float* input,
  2637. uint8_t* output,
  2638. float output_scale,
  2639. uint8_t output_zero_point,
  2640. uint32_t flags,
  2641. pthreadpool_t threadpool);
  2642. enum xnn_status xnn_create_convert_nc_qs8(
  2643. float input_scale,
  2644. int8_t input_zero_point,
  2645. float output_scale,
  2646. int8_t output_zero_point,
  2647. uint32_t flags,
  2648. xnn_operator_t* convert_op_out);
  2649. enum xnn_status xnn_reshape_convert_nc_qs8(
  2650. xnn_operator_t convert_op,
  2651. size_t batch_size,
  2652. size_t channels,
  2653. size_t input_stride,
  2654. size_t output_stride,
  2655. pthreadpool_t threadpool);
  2656. enum xnn_status xnn_setup_convert_nc_qs8(
  2657. xnn_operator_t convert_op,
  2658. const int8_t* input,
  2659. int8_t* output);
  2660. enum xnn_status xnn_create_convert_nc_qs8_f16(
  2661. float input_scale,
  2662. int8_t input_zero_point,
  2663. uint32_t flags,
  2664. xnn_operator_t* convert_op_out);
  2665. enum xnn_status xnn_reshape_convert_nc_qs8_f16(
  2666. xnn_operator_t convert_op,
  2667. size_t batch_size,
  2668. size_t channels,
  2669. size_t input_stride,
  2670. size_t output_stride,
  2671. pthreadpool_t threadpool);
  2672. enum xnn_status xnn_setup_convert_nc_qs8_f16(
  2673. xnn_operator_t convert_op,
  2674. const int8_t* input,
  2675. void* output);
  2676. enum xnn_status xnn_create_convert_nc_qs8_f32(
  2677. float input_scale,
  2678. int8_t input_zero_point,
  2679. uint32_t flags,
  2680. xnn_operator_t* convert_op_out);
  2681. enum xnn_status xnn_reshape_convert_nc_qs8_f32(
  2682. xnn_operator_t convert_op,
  2683. size_t batch_size,
  2684. size_t channels,
  2685. size_t input_stride,
  2686. size_t output_stride,
  2687. pthreadpool_t threadpool);
  2688. enum xnn_status xnn_setup_convert_nc_qs8_f32(
  2689. xnn_operator_t convert_op,
  2690. const int8_t* input,
  2691. float* output);
  2692. enum xnn_status xnn_run_convert_nc_qs8_f32(
  2693. size_t channels,
  2694. size_t input_stride,
  2695. size_t output_stride,
  2696. size_t batch_size,
  2697. const int8_t* input,
  2698. float* output,
  2699. float input_scale,
  2700. int8_t input_zero_point,
  2701. uint32_t flags,
  2702. pthreadpool_t threadpool);
  2703. enum xnn_status xnn_create_convert_nc_qs16_qs8(
  2704. float input_scale,
  2705. float output_scale,
  2706. int8_t output_zero_point,
  2707. uint32_t flags,
  2708. xnn_operator_t* convert_op_out);
  2709. enum xnn_status xnn_reshape_convert_nc_qs16_qs8(
  2710. xnn_operator_t convert_op,
  2711. size_t batch_size,
  2712. size_t channels,
  2713. size_t input_stride,
  2714. size_t output_stride,
  2715. pthreadpool_t threadpool);
  2716. enum xnn_status xnn_setup_convert_nc_qs16_qs8(
  2717. xnn_operator_t convert_op,
  2718. const int16_t* input,
  2719. int8_t* output);
  2720. enum xnn_status xnn_run_convert_nc_qs16_qs8(
  2721. size_t channels,
  2722. size_t input_stride,
  2723. size_t output_stride,
  2724. size_t batch_size,
  2725. const int16_t* input,
  2726. int8_t* output,
  2727. float input_scale,
  2728. float output_scale,
  2729. int8_t output_zero_point,
  2730. uint32_t flags,
  2731. pthreadpool_t threadpool);
  2732. enum xnn_status xnn_create_convert_nc_qu8(
  2733. float input_scale,
  2734. uint8_t input_zero_point,
  2735. float output_scale,
  2736. uint8_t output_zero_point,
  2737. uint32_t flags,
  2738. xnn_operator_t* convert_op_out);
  2739. enum xnn_status xnn_reshape_convert_nc_qu8(
  2740. xnn_operator_t convert_op,
  2741. size_t batch_size,
  2742. size_t channels,
  2743. size_t input_stride,
  2744. size_t output_stride,
  2745. pthreadpool_t threadpool);
  2746. enum xnn_status xnn_setup_convert_nc_qu8(
  2747. xnn_operator_t convert_op,
  2748. const uint8_t* input,
  2749. uint8_t* output);
  2750. enum xnn_status xnn_create_convert_nc_qu8_f32(
  2751. float input_scale,
  2752. uint8_t input_zero_point,
  2753. uint32_t flags,
  2754. xnn_operator_t* convert_op_out);
  2755. enum xnn_status xnn_reshape_convert_nc_qu8_f32(
  2756. xnn_operator_t convert_op,
  2757. size_t batch_size,
  2758. size_t channels,
  2759. size_t input_stride,
  2760. size_t output_stride,
  2761. pthreadpool_t threadpool);
  2762. enum xnn_status xnn_setup_convert_nc_qu8_f32(
  2763. xnn_operator_t convert_op,
  2764. const uint8_t* input,
  2765. float* output);
  2766. enum xnn_status xnn_run_convert_nc_qu8_f32(
  2767. size_t channels,
  2768. size_t input_stride,
  2769. size_t output_stride,
  2770. size_t batch_size,
  2771. const uint8_t* input,
  2772. float* output,
  2773. float input_scale,
  2774. uint8_t input_zero_point,
  2775. uint32_t flags,
  2776. pthreadpool_t threadpool);
  2777. enum xnn_status xnn_create_convolution2d_nchw_f16(
  2778. uint32_t input_padding_top,
  2779. uint32_t input_padding_right,
  2780. uint32_t input_padding_bottom,
  2781. uint32_t input_padding_left,
  2782. uint32_t kernel_height,
  2783. uint32_t kernel_width,
  2784. uint32_t subsampling_height,
  2785. uint32_t subsampling_width,
  2786. uint32_t dilation_height,
  2787. uint32_t dilation_width,
  2788. uint32_t groups,
  2789. size_t group_input_channels,
  2790. size_t group_output_channels,
  2791. size_t input_channel_stride,
  2792. size_t output_channel_stride,
  2793. const void* kernel,
  2794. const void* bias,
  2795. float output_min,
  2796. float output_max,
  2797. uint32_t flags,
  2798. xnn_code_cache_t code_cache,
  2799. xnn_weights_cache_t weights_cache,
  2800. xnn_operator_t* convolution_op_out);
  2801. enum xnn_status xnn_reshape_convolution2d_nchw_f16(
  2802. xnn_operator_t convolution_op,
  2803. size_t batch_size,
  2804. size_t input_height,
  2805. size_t input_width,
  2806. size_t* output_height_out,
  2807. size_t* output_width_out,
  2808. pthreadpool_t threadpool);
  2809. enum xnn_status xnn_setup_convolution2d_nchw_f16(
  2810. xnn_operator_t convolution_op,
  2811. const void* input,
  2812. void* output);
  2813. enum xnn_status xnn_create_convolution2d_nchw_f32(
  2814. uint32_t input_padding_top,
  2815. uint32_t input_padding_right,
  2816. uint32_t input_padding_bottom,
  2817. uint32_t input_padding_left,
  2818. uint32_t kernel_height,
  2819. uint32_t kernel_width,
  2820. uint32_t subsampling_height,
  2821. uint32_t subsampling_width,
  2822. uint32_t dilation_height,
  2823. uint32_t dilation_width,
  2824. uint32_t groups,
  2825. size_t group_input_channels,
  2826. size_t group_output_channels,
  2827. size_t input_channel_stride,
  2828. size_t output_channel_stride,
  2829. const float* kernel,
  2830. const float* bias,
  2831. float output_min,
  2832. float output_max,
  2833. uint32_t flags,
  2834. xnn_code_cache_t code_cache,
  2835. xnn_weights_cache_t weights_cache,
  2836. xnn_operator_t* convolution_op_out);
  2837. enum xnn_status xnn_reshape_convolution2d_nchw_f32(
  2838. xnn_operator_t convolution_op,
  2839. size_t batch_size,
  2840. size_t input_height,
  2841. size_t input_width,
  2842. size_t* output_height_out,
  2843. size_t* output_width_out,
  2844. pthreadpool_t threadpool);
  2845. enum xnn_status xnn_setup_convolution2d_nchw_f32(
  2846. xnn_operator_t convolution_op,
  2847. const float* input,
  2848. float* output);
  2849. enum xnn_status xnn_create_convolution2d_nhwc_f16(
  2850. uint32_t input_padding_top,
  2851. uint32_t input_padding_right,
  2852. uint32_t input_padding_bottom,
  2853. uint32_t input_padding_left,
  2854. uint32_t kernel_height,
  2855. uint32_t kernel_width,
  2856. uint32_t subsampling_height,
  2857. uint32_t subsampling_width,
  2858. uint32_t dilation_height,
  2859. uint32_t dilation_width,
  2860. uint32_t groups,
  2861. size_t group_input_channels,
  2862. size_t group_output_channels,
  2863. size_t input_channel_stride,
  2864. size_t output_channel_stride,
  2865. const void* kernel,
  2866. const void* bias,
  2867. float output_min,
  2868. float output_max,
  2869. uint32_t flags,
  2870. xnn_code_cache_t code_cache,
  2871. xnn_weights_cache_t weights_cache,
  2872. xnn_operator_t* convolution_op_out);
  2873. enum xnn_status xnn_reshape_convolution2d_nhwc_f16(
  2874. xnn_operator_t convolution_op,
  2875. size_t batch_size,
  2876. size_t input_height,
  2877. size_t input_width,
  2878. size_t* workspace_size,
  2879. size_t* workspace_alignment,
  2880. size_t* output_height_out,
  2881. size_t* output_width_out,
  2882. pthreadpool_t threadpool);
  2883. enum xnn_status xnn_setup_convolution2d_nhwc_f16(
  2884. xnn_operator_t convolution_op,
  2885. void* workspace,
  2886. const void* input,
  2887. void* output);
  2888. enum xnn_status xnn_create_convolution2d_nhwc_f32(
  2889. uint32_t input_padding_top,
  2890. uint32_t input_padding_right,
  2891. uint32_t input_padding_bottom,
  2892. uint32_t input_padding_left,
  2893. uint32_t kernel_height,
  2894. uint32_t kernel_width,
  2895. uint32_t subsampling_height,
  2896. uint32_t subsampling_width,
  2897. uint32_t dilation_height,
  2898. uint32_t dilation_width,
  2899. uint32_t groups,
  2900. size_t group_input_channels,
  2901. size_t group_output_channels,
  2902. size_t input_channel_stride,
  2903. size_t output_channel_stride,
  2904. const float* kernel,
  2905. const float* bias,
  2906. float output_min,
  2907. float output_max,
  2908. uint32_t flags,
  2909. xnn_code_cache_t code_cache,
  2910. xnn_weights_cache_t weights_cache,
  2911. xnn_operator_t* convolution_op_out);
  2912. // Forward declare.
  2913. struct xnn_post_operation;
  2914. /// Create a convolution operator with a number of post operations. The
  2915. /// convolution operator created using this function does not have output_min
  2916. /// and output_max. The list of operators in post_operations will be applied in
  2917. /// order. Convolution with post operations is only supported on JIT platforms
  2918. /// and when JIT is enabled.
  2919. enum xnn_status xnn_create_fused_convolution2d_nhwc_f32(
  2920. uint32_t input_padding_top,
  2921. uint32_t input_padding_right,
  2922. uint32_t input_padding_bottom,
  2923. uint32_t input_padding_left,
  2924. uint32_t kernel_height,
  2925. uint32_t kernel_width,
  2926. uint32_t subsampling_height,
  2927. uint32_t subsampling_width,
  2928. uint32_t dilation_height,
  2929. uint32_t dilation_width,
  2930. uint32_t groups,
  2931. size_t group_input_channels,
  2932. size_t group_output_channels,
  2933. size_t input_channel_stride,
  2934. size_t output_channel_stride,
  2935. const float* kernel,
  2936. const float* bias,
  2937. size_t num_post_operations,
  2938. struct xnn_post_operation* post_operations,
  2939. uint32_t flags,
  2940. xnn_code_cache_t code_cache,
  2941. xnn_weights_cache_t weights_cache,
  2942. xnn_operator_t* convolution_op_out);
  2943. enum xnn_status xnn_reshape_convolution2d_nhwc_f32(
  2944. xnn_operator_t convolution_op,
  2945. size_t batch_size,
  2946. size_t input_height,
  2947. size_t input_width,
  2948. size_t* workspace_size,
  2949. size_t* workspace_alignment,
  2950. size_t* output_height_out,
  2951. size_t* output_width_out,
  2952. pthreadpool_t threadpool);
  2953. enum xnn_status xnn_setup_convolution2d_nhwc_f32(
  2954. xnn_operator_t convolution_op,
  2955. void* workspace,
  2956. const float* input,
  2957. float* output);
  2958. enum xnn_status xnn_create_convolution2d_nhwc_qd8_f16_qc8w(
  2959. uint32_t input_padding_top, uint32_t input_padding_right,
  2960. uint32_t input_padding_bottom, uint32_t input_padding_left,
  2961. uint32_t kernel_height, uint32_t kernel_width, uint32_t subsampling_height,
  2962. uint32_t subsampling_width, uint32_t dilation_height,
  2963. uint32_t dilation_width, uint32_t groups, size_t group_input_channels,
  2964. size_t group_output_channels, size_t input_channel_stride,
  2965. size_t output_channel_stride, const float* kernel_scale,
  2966. const int8_t* kernel, const float* bias, float output_min, float output_max,
  2967. uint32_t flags, xnn_code_cache_t code_cache,
  2968. xnn_weights_cache_t weights_cache, xnn_operator_t* convolution_op_out);
  2969. enum xnn_status xnn_create_convolution2d_nhwc_qd8_f32_qc8w(
  2970. uint32_t input_padding_top, uint32_t input_padding_right,
  2971. uint32_t input_padding_bottom, uint32_t input_padding_left,
  2972. uint32_t kernel_height, uint32_t kernel_width, uint32_t subsampling_height,
  2973. uint32_t subsampling_width, uint32_t dilation_height,
  2974. uint32_t dilation_width, uint32_t groups, size_t group_input_channels,
  2975. size_t group_output_channels, size_t input_channel_stride,
  2976. size_t output_channel_stride, const float* kernel_scale,
  2977. const int8_t* kernel, const float* bias, float output_min, float output_max,
  2978. uint32_t flags, xnn_code_cache_t code_cache,
  2979. xnn_weights_cache_t weights_cache, xnn_operator_t* convolution_op_out);
  2980. enum xnn_status xnn_create_convolution2d_nhwc_qs8(
  2981. uint32_t input_padding_top,
  2982. uint32_t input_padding_right,
  2983. uint32_t input_padding_bottom,
  2984. uint32_t input_padding_left,
  2985. uint32_t kernel_height,
  2986. uint32_t kernel_width,
  2987. uint32_t subsampling_height,
  2988. uint32_t subsampling_width,
  2989. uint32_t dilation_height,
  2990. uint32_t dilation_width,
  2991. uint32_t groups,
  2992. size_t group_input_channels,
  2993. size_t group_output_channels,
  2994. size_t input_channel_stride,
  2995. size_t output_channel_stride,
  2996. int8_t input_zero_point,
  2997. float input_scale,
  2998. float kernel_scale,
  2999. const int8_t* kernel,
  3000. const int32_t* bias,
  3001. int8_t output_zero_point,
  3002. float output_scale,
  3003. int8_t output_min,
  3004. int8_t output_max,
  3005. uint32_t flags,
  3006. xnn_code_cache_t code_cache,
  3007. xnn_weights_cache_t weights_cache,
  3008. xnn_operator_t* convolution_op_out);
  3009. enum xnn_status xnn_reshape_convolution2d_nhwc_qd8_f16_qc8w(
  3010. xnn_operator_t convolution_op, size_t batch_size, size_t input_height,
  3011. size_t input_width, size_t* workspace_size, size_t* workspace_alignment,
  3012. size_t* output_height_out, size_t* output_width_out,
  3013. pthreadpool_t threadpool);
  3014. enum xnn_status xnn_reshape_convolution2d_nhwc_qd8_f32_qc8w(
  3015. xnn_operator_t convolution_op, size_t batch_size, size_t input_height,
  3016. size_t input_width, size_t* workspace_size, size_t* workspace_alignment,
  3017. size_t* output_height_out, size_t* output_width_out,
  3018. pthreadpool_t threadpool);
  3019. enum xnn_status xnn_reshape_convolution2d_nhwc_qs8(
  3020. xnn_operator_t convolution_op,
  3021. size_t batch_size,
  3022. size_t input_height,
  3023. size_t input_width,
  3024. size_t* workspace_size,
  3025. size_t* workspace_alignment,
  3026. size_t* output_height_out,
  3027. size_t* output_width_out,
  3028. pthreadpool_t threadpool);
  3029. enum xnn_status xnn_setup_convolution2d_nhwc_qd8_f16_qc8w(
  3030. xnn_operator_t convolution_op, void* workspace, const int8_t* input,
  3031. void* output,
  3032. const struct xnn_dynamic_quantization_params* quantization_params);
  3033. enum xnn_status xnn_setup_convolution2d_nhwc_qd8_f32_qc8w(
  3034. xnn_operator_t convolution_op, void* workspace, const int8_t* input,
  3035. float* output,
  3036. const struct xnn_dynamic_quantization_params* quantization_params);
  3037. enum xnn_status xnn_setup_convolution2d_nhwc_qs8(
  3038. xnn_operator_t convolution_op,
  3039. void* workspace,
  3040. const int8_t* input,
  3041. int8_t* output);
  3042. enum xnn_status xnn_create_convolution2d_nhwc_qs8_qc8w(
  3043. uint32_t input_padding_top,
  3044. uint32_t input_padding_right,
  3045. uint32_t input_padding_bottom,
  3046. uint32_t input_padding_left,
  3047. uint32_t kernel_height,
  3048. uint32_t kernel_width,
  3049. uint32_t subsampling_height,
  3050. uint32_t subsampling_width,
  3051. uint32_t dilation_height,
  3052. uint32_t dilation_width,
  3053. uint32_t groups,
  3054. size_t group_input_channels,
  3055. size_t group_output_channels,
  3056. size_t input_channel_stride,
  3057. size_t output_channel_stride,
  3058. int8_t input_zero_point,
  3059. float input_scale,
  3060. const float* kernel_scale,
  3061. const int8_t* kernel,
  3062. const int32_t* bias,
  3063. int8_t output_zero_point,
  3064. float output_scale,
  3065. int8_t output_min,
  3066. int8_t output_max,
  3067. uint32_t flags,
  3068. xnn_code_cache_t code_cache,
  3069. xnn_weights_cache_t weights_cache,
  3070. xnn_operator_t* convolution_op_out);
  3071. enum xnn_status xnn_reshape_convolution2d_nhwc_qs8_qc8w(
  3072. xnn_operator_t convolution_op,
  3073. size_t batch_size,
  3074. size_t input_height,
  3075. size_t input_width,
  3076. size_t* workspace_size,
  3077. size_t* workspace_alignment,
  3078. size_t* output_height_out,
  3079. size_t* output_width_out,
  3080. pthreadpool_t threadpool);
  3081. enum xnn_status xnn_setup_convolution2d_nhwc_qs8_qc8w(
  3082. xnn_operator_t convolution_op,
  3083. void* workspace,
  3084. const int8_t* input,
  3085. int8_t* output);
  3086. enum xnn_status xnn_create_convolution2d_nhwc_qu8(
  3087. uint32_t input_padding_top,
  3088. uint32_t input_padding_right,
  3089. uint32_t input_padding_bottom,
  3090. uint32_t input_padding_left,
  3091. uint32_t kernel_height,
  3092. uint32_t kernel_width,
  3093. uint32_t subsampling_height,
  3094. uint32_t subsampling_width,
  3095. uint32_t dilation_height,
  3096. uint32_t dilation_width,
  3097. uint32_t groups,
  3098. size_t group_input_channels,
  3099. size_t group_output_channels,
  3100. size_t input_channel_stride,
  3101. size_t output_channel_stride,
  3102. uint8_t input_zero_point,
  3103. float input_scale,
  3104. uint8_t kernel_zero_point,
  3105. float kernel_scale,
  3106. const uint8_t* kernel,
  3107. const int32_t* bias,
  3108. uint8_t output_zero_point,
  3109. float output_scale,
  3110. uint8_t output_min,
  3111. uint8_t output_max,
  3112. uint32_t flags,
  3113. xnn_code_cache_t code_cache,
  3114. xnn_weights_cache_t weights_cache,
  3115. xnn_operator_t* convolution_op_out);
  3116. enum xnn_status xnn_reshape_convolution2d_nhwc_qu8(
  3117. xnn_operator_t convolution_op,
  3118. size_t batch_size,
  3119. size_t input_height,
  3120. size_t input_width,
  3121. size_t* workspace_size,
  3122. size_t* workspace_alignment,
  3123. size_t* output_height_out,
  3124. size_t* output_width_out,
  3125. pthreadpool_t threadpool);
  3126. enum xnn_status xnn_setup_convolution2d_nhwc_qu8(
  3127. xnn_operator_t convolution_op,
  3128. void* workspace,
  3129. const uint8_t* input,
  3130. uint8_t* output);
  3131. enum xnn_status xnn_create_copy_nc_x8(
  3132. uint32_t flags,
  3133. xnn_operator_t* copy_op_out);
  3134. enum xnn_status xnn_reshape_copy_nc_x8(
  3135. xnn_operator_t copy_op,
  3136. size_t batch_size,
  3137. size_t channels,
  3138. size_t input_stride,
  3139. size_t output_stride,
  3140. pthreadpool_t threadpool);
  3141. enum xnn_status xnn_setup_copy_nc_x8(
  3142. xnn_operator_t copy_op,
  3143. const void* input,
  3144. void* output);
  3145. enum xnn_status xnn_create_copy_nc_x16(
  3146. uint32_t flags,
  3147. xnn_operator_t* copy_op_out);
  3148. enum xnn_status xnn_reshape_copy_nc_x16(
  3149. xnn_operator_t copy_op,
  3150. size_t batch_size,
  3151. size_t channels,
  3152. size_t input_stride,
  3153. size_t output_stride,
  3154. pthreadpool_t threadpool);
  3155. enum xnn_status xnn_setup_copy_nc_x16(
  3156. xnn_operator_t copy_op,
  3157. const void* input,
  3158. void* output);
  3159. enum xnn_status xnn_create_copy_nc_x32(
  3160. uint32_t flags,
  3161. xnn_operator_t* copy_op_out);
  3162. enum xnn_status xnn_reshape_copy_nc_x32(
  3163. xnn_operator_t copy_op,
  3164. size_t batch_size,
  3165. size_t channels,
  3166. size_t input_stride,
  3167. size_t output_stride,
  3168. pthreadpool_t threadpool);
  3169. enum xnn_status xnn_setup_copy_nc_x32(
  3170. xnn_operator_t copy_op,
  3171. const void* input,
  3172. void* output);
  3173. enum xnn_status xnn_run_copy_nc_x32(
  3174. size_t channels,
  3175. size_t input_stride,
  3176. size_t output_stride,
  3177. size_t batch_size,
  3178. const uint32_t* input,
  3179. uint32_t* output,
  3180. uint32_t flags,
  3181. pthreadpool_t threadpool);
  3182. enum xnn_status xnn_create_deconvolution2d_nhwc_f16(
  3183. uint32_t output_padding_top,
  3184. uint32_t output_padding_right,
  3185. uint32_t output_padding_bottom,
  3186. uint32_t output_padding_left,
  3187. uint32_t kernel_height,
  3188. uint32_t kernel_width,
  3189. uint32_t stride_height,
  3190. uint32_t stride_width,
  3191. uint32_t dilation_height,
  3192. uint32_t dilation_width,
  3193. uint32_t groups,
  3194. size_t group_input_channels,
  3195. size_t group_output_channels,
  3196. size_t input_pixel_stride,
  3197. size_t output_pixel_stride,
  3198. const void* kernel,
  3199. const void* bias,
  3200. float output_min,
  3201. float output_max,
  3202. uint32_t flags,
  3203. xnn_code_cache_t code_cache,
  3204. xnn_weights_cache_t weights_cache,
  3205. xnn_operator_t* deconvolution_op_out);
  3206. enum xnn_status xnn_reshape_deconvolution2d_nhwc_f16(
  3207. xnn_operator_t deconvolution_op,
  3208. size_t batch_size,
  3209. size_t input_height,
  3210. size_t input_width,
  3211. uint32_t adjustment_height,
  3212. uint32_t adjustment_width,
  3213. size_t* output_height_out,
  3214. size_t* output_width_out,
  3215. pthreadpool_t threadpool);
  3216. enum xnn_status xnn_setup_deconvolution2d_nhwc_f16(
  3217. xnn_operator_t deconvolution_op,
  3218. const void* input,
  3219. void* output);
  3220. enum xnn_status xnn_create_deconvolution2d_nhwc_f32(
  3221. uint32_t output_padding_top,
  3222. uint32_t output_padding_right,
  3223. uint32_t output_padding_bottom,
  3224. uint32_t output_padding_left,
  3225. uint32_t kernel_height,
  3226. uint32_t kernel_width,
  3227. uint32_t stride_height,
  3228. uint32_t stride_width,
  3229. uint32_t dilation_height,
  3230. uint32_t dilation_width,
  3231. uint32_t groups,
  3232. size_t group_input_channels,
  3233. size_t group_output_channels,
  3234. size_t input_pixel_stride,
  3235. size_t output_pixel_stride,
  3236. const float* kernel,
  3237. const float* bias,
  3238. float output_min,
  3239. float output_max,
  3240. uint32_t flags,
  3241. xnn_code_cache_t code_cache,
  3242. xnn_weights_cache_t weights_cache,
  3243. xnn_operator_t* deconvolution_op_out);
  3244. enum xnn_status xnn_reshape_deconvolution2d_nhwc_f32(
  3245. xnn_operator_t deconvolution_op,
  3246. size_t batch_size,
  3247. size_t input_height,
  3248. size_t input_width,
  3249. uint32_t adjustment_height,
  3250. uint32_t adjustment_width,
  3251. size_t* output_height_out,
  3252. size_t* output_width_out,
  3253. pthreadpool_t threadpool);
  3254. enum xnn_status xnn_setup_deconvolution2d_nhwc_f32(
  3255. xnn_operator_t deconvolution_op,
  3256. const float* input,
  3257. float* output);
  3258. enum xnn_status xnn_create_deconvolution2d_nhwc_qs8(
  3259. uint32_t output_padding_top,
  3260. uint32_t output_padding_right,
  3261. uint32_t output_padding_bottom,
  3262. uint32_t output_padding_left,
  3263. uint32_t kernel_height,
  3264. uint32_t kernel_width,
  3265. uint32_t stride_height,
  3266. uint32_t stride_width,
  3267. uint32_t dilation_height,
  3268. uint32_t dilation_width,
  3269. uint32_t groups,
  3270. size_t group_input_channels,
  3271. size_t group_output_channels,
  3272. size_t input_pixel_stride,
  3273. size_t output_pixel_stride,
  3274. int8_t input_zero_point,
  3275. float input_scale,
  3276. float kernel_scale,
  3277. const int8_t* kernel,
  3278. const int32_t* bias,
  3279. int8_t output_zero_point,
  3280. float output_scale,
  3281. int8_t output_min,
  3282. int8_t output_max,
  3283. uint32_t flags,
  3284. xnn_code_cache_t code_cache,
  3285. xnn_weights_cache_t weights_cache,
  3286. xnn_operator_t* deconvolution_op_out);
  3287. enum xnn_status xnn_reshape_deconvolution2d_nhwc_qs8(
  3288. xnn_operator_t deconvolution_op,
  3289. size_t batch_size,
  3290. size_t input_height,
  3291. size_t input_width,
  3292. uint32_t adjustment_height,
  3293. uint32_t adjustment_width,
  3294. size_t* output_height_out,
  3295. size_t* output_width_out,
  3296. pthreadpool_t threadpool);
  3297. enum xnn_status xnn_setup_deconvolution2d_nhwc_qs8(
  3298. xnn_operator_t deconvolution_op,
  3299. const int8_t* input,
  3300. int8_t* output);
  3301. enum xnn_status xnn_create_deconvolution2d_nhwc_qu8(
  3302. uint32_t output_padding_top,
  3303. uint32_t output_padding_right,
  3304. uint32_t output_padding_bottom,
  3305. uint32_t output_padding_left,
  3306. uint32_t kernel_height,
  3307. uint32_t kernel_width,
  3308. uint32_t stride_height,
  3309. uint32_t stride_width,
  3310. uint32_t dilation_height,
  3311. uint32_t dilation_width,
  3312. uint32_t groups,
  3313. size_t group_input_channels,
  3314. size_t group_output_channels,
  3315. size_t input_pixel_stride,
  3316. size_t output_pixel_stride,
  3317. uint8_t input_zero_point,
  3318. float input_scale,
  3319. uint8_t kernel_zero_point,
  3320. float kernel_scale,
  3321. const uint8_t* kernel,
  3322. const int32_t* bias,
  3323. uint8_t output_zero_point,
  3324. float output_scale,
  3325. uint8_t output_min,
  3326. uint8_t output_max,
  3327. uint32_t flags,
  3328. xnn_code_cache_t code_cache,
  3329. xnn_weights_cache_t weights_cache,
  3330. xnn_operator_t* deconvolution_op_out);
  3331. enum xnn_status xnn_reshape_deconvolution2d_nhwc_qu8(
  3332. xnn_operator_t deconvolution_op,
  3333. size_t batch_size,
  3334. size_t input_height,
  3335. size_t input_width,
  3336. uint32_t adjustment_height,
  3337. uint32_t adjustment_width,
  3338. size_t* output_height_out,
  3339. size_t* output_width_out,
  3340. pthreadpool_t threadpool);
  3341. enum xnn_status xnn_setup_deconvolution2d_nhwc_qu8(
  3342. xnn_operator_t deconvolution_op,
  3343. const uint8_t* input,
  3344. uint8_t* output);
  3345. enum xnn_status xnn_create_depth_to_space_nchw2nhwc_x16(
  3346. uint32_t block_size,
  3347. uint32_t flags,
  3348. xnn_operator_t* depth_to_space_op_out);
  3349. enum xnn_status xnn_reshape_depth_to_space_nchw2nhwc_x16(
  3350. xnn_operator_t depth_to_space_op,
  3351. size_t batch_size,
  3352. size_t input_height,
  3353. size_t input_width,
  3354. size_t input_channels,
  3355. size_t* output_height_out,
  3356. size_t* output_width_out,
  3357. size_t* output_channels_out,
  3358. pthreadpool_t threadpool);
  3359. enum xnn_status xnn_setup_depth_to_space_nchw2nhwc_x16(
  3360. xnn_operator_t depth_to_space_op,
  3361. const void* input,
  3362. void* output);
  3363. enum xnn_status xnn_create_depth_to_space_nchw2nhwc_x32(
  3364. uint32_t block_size,
  3365. uint32_t flags,
  3366. xnn_operator_t* depth_to_space_op_out);
  3367. enum xnn_status xnn_reshape_depth_to_space_nchw2nhwc_x32(
  3368. xnn_operator_t depth_to_space_op,
  3369. size_t batch_size,
  3370. size_t input_height,
  3371. size_t input_width,
  3372. size_t input_channels,
  3373. size_t* output_height_out,
  3374. size_t* output_width_out,
  3375. size_t* output_channels_out,
  3376. pthreadpool_t threadpool);
  3377. enum xnn_status xnn_setup_depth_to_space_nchw2nhwc_x32(
  3378. xnn_operator_t depth_to_space_op,
  3379. const void* input,
  3380. void* output);
  3381. enum xnn_status xnn_create_depth_to_space_nhwc_x8(
  3382. uint32_t block_size,
  3383. uint32_t flags,
  3384. xnn_operator_t* depth_to_space_op_out);
  3385. enum xnn_status xnn_reshape_depth_to_space_nhwc_x8(
  3386. xnn_operator_t depth_to_space_op,
  3387. size_t batch_size,
  3388. size_t input_height,
  3389. size_t input_width,
  3390. size_t input_channels,
  3391. size_t* output_height_out,
  3392. size_t* output_width_out,
  3393. size_t* output_channels_out,
  3394. pthreadpool_t threadpool);
  3395. enum xnn_status xnn_setup_depth_to_space_nhwc_x8(
  3396. xnn_operator_t depth_to_space_op,
  3397. const void* input,
  3398. void* output);
  3399. enum xnn_status xnn_create_depth_to_space_nhwc_x16(
  3400. uint32_t block_size,
  3401. uint32_t flags,
  3402. xnn_operator_t* depth_to_space_op_out);
  3403. enum xnn_status xnn_reshape_depth_to_space_nhwc_x16(
  3404. xnn_operator_t depth_to_space_op,
  3405. size_t batch_size,
  3406. size_t input_height,
  3407. size_t input_width,
  3408. size_t input_channels,
  3409. size_t* output_height_out,
  3410. size_t* output_width_out,
  3411. size_t* output_channels_out,
  3412. pthreadpool_t threadpool);
  3413. enum xnn_status xnn_setup_depth_to_space_nhwc_x16(
  3414. xnn_operator_t depth_to_space_op,
  3415. const void* input,
  3416. void* output);
  3417. enum xnn_status xnn_create_depth_to_space_nhwc_x32(
  3418. uint32_t block_size,
  3419. uint32_t flags,
  3420. xnn_operator_t* depth_to_space_op_out);
  3421. enum xnn_status xnn_reshape_depth_to_space_nhwc_x32(
  3422. xnn_operator_t depth_to_space_op,
  3423. size_t batch_size,
  3424. size_t input_height,
  3425. size_t input_width,
  3426. size_t input_channels,
  3427. size_t* output_height_out,
  3428. size_t* output_width_out,
  3429. size_t* output_channels_out,
  3430. pthreadpool_t threadpool);
  3431. enum xnn_status xnn_setup_depth_to_space_nhwc_x32(
  3432. xnn_operator_t depth_to_space_op,
  3433. const void* input,
  3434. void* output);
  3435. enum xnn_status xnn_create_divide_nd_f16(
  3436. float output_min,
  3437. float output_max,
  3438. uint32_t flags,
  3439. xnn_operator_t* divide_op_out);
  3440. enum xnn_status xnn_reshape_divide_nd_f16(
  3441. xnn_operator_t divide_op,
  3442. size_t num_input1_dims,
  3443. const size_t* input1_shape,
  3444. size_t num_input2_dims,
  3445. const size_t* input2_shape,
  3446. pthreadpool_t threadpool);
  3447. enum xnn_status xnn_setup_divide_nd_f16(
  3448. xnn_operator_t divide_op,
  3449. const void* input1,
  3450. const void* input2,
  3451. void* output);
  3452. enum xnn_status xnn_create_divide_nd_f32(
  3453. float output_min,
  3454. float output_max,
  3455. uint32_t flags,
  3456. xnn_operator_t* divide_op_out);
  3457. enum xnn_status xnn_reshape_divide_nd_f32(
  3458. xnn_operator_t divide_op,
  3459. size_t num_input1_dims,
  3460. const size_t* input1_shape,
  3461. size_t num_input2_dims,
  3462. const size_t* input2_shape,
  3463. pthreadpool_t threadpool);
  3464. enum xnn_status xnn_setup_divide_nd_f32(
  3465. xnn_operator_t divide_op,
  3466. const float* input1,
  3467. const float* input2,
  3468. float* output);
  3469. enum xnn_status xnn_run_divide_nd_f32(
  3470. size_t num_input1_dims,
  3471. const size_t* input1_shape,
  3472. size_t num_input2_dims,
  3473. const size_t* input2_shape,
  3474. const float* input1,
  3475. const float* input2,
  3476. float* output,
  3477. float output_min,
  3478. float output_max,
  3479. uint32_t flags,
  3480. pthreadpool_t threadpool);
  3481. enum xnn_status xnn_create_dynamic_fully_connected_nc_f16(
  3482. float output_min,
  3483. float output_max,
  3484. uint32_t flags,
  3485. xnn_operator_t* dynamic_fully_connected_op_out);
  3486. enum xnn_status xnn_reshape_dynamic_fully_connected_nc_f16(
  3487. xnn_operator_t dynamic_fully_connected_op,
  3488. size_t batch_size,
  3489. size_t input_channels,
  3490. size_t output_channels,
  3491. size_t input_stride,
  3492. size_t output_stride,
  3493. size_t* workspace_size,
  3494. size_t* workspace_alignment,
  3495. pthreadpool_t threadpool);
  3496. enum xnn_status xnn_setup_dynamic_fully_connected_nc_f16(
  3497. xnn_operator_t dynamic_fully_connected_op,
  3498. void* workspace,
  3499. const void* input,
  3500. const void* kernel,
  3501. const void* bias,
  3502. void* output);
  3503. enum xnn_status xnn_create_dynamic_fully_connected_nc_f32(
  3504. float output_min,
  3505. float output_max,
  3506. uint32_t flags,
  3507. xnn_operator_t* dynamic_fully_connected_op_out);
  3508. enum xnn_status xnn_reshape_dynamic_fully_connected_nc_f32(
  3509. xnn_operator_t dynamic_fully_connected_op,
  3510. size_t batch_size,
  3511. size_t input_channels,
  3512. size_t output_channels,
  3513. size_t input_stride,
  3514. size_t output_stride,
  3515. size_t* workspace_size,
  3516. size_t* workspace_alignment,
  3517. pthreadpool_t threadpool);
  3518. enum xnn_status xnn_setup_dynamic_fully_connected_nc_f32(
  3519. xnn_operator_t dynamic_fully_connected_op,
  3520. void* workspace,
  3521. const float* input,
  3522. const float* kernel,
  3523. const float* bias,
  3524. float* output);
  3525. enum xnn_status xnn_create_elu_nc_f16(
  3526. float alpha,
  3527. uint32_t flags,
  3528. xnn_operator_t* elu_op_out);
  3529. enum xnn_status xnn_reshape_elu_nc_f16(
  3530. xnn_operator_t elu_op,
  3531. size_t batch_size,
  3532. size_t channels,
  3533. size_t input_stride,
  3534. size_t output_stride,
  3535. pthreadpool_t threadpool);
  3536. enum xnn_status xnn_setup_elu_nc_f16(
  3537. xnn_operator_t elu_op,
  3538. const void* input,
  3539. void* output);
  3540. enum xnn_status xnn_create_elu_nc_f32(
  3541. float alpha,
  3542. uint32_t flags,
  3543. xnn_operator_t* elu_op_out);
  3544. enum xnn_status xnn_reshape_elu_nc_f32(
  3545. xnn_operator_t elu_op,
  3546. size_t batch_size,
  3547. size_t channels,
  3548. size_t input_stride,
  3549. size_t output_stride,
  3550. pthreadpool_t threadpool);
  3551. enum xnn_status xnn_setup_elu_nc_f32(
  3552. xnn_operator_t elu_op,
  3553. const float* input,
  3554. float* output);
  3555. enum xnn_status xnn_run_elu_nc_f32(
  3556. size_t channels,
  3557. size_t input_stride,
  3558. size_t output_stride,
  3559. size_t batch_size,
  3560. const float* input,
  3561. float* output,
  3562. float alpha,
  3563. uint32_t flags,
  3564. pthreadpool_t threadpool);
  3565. enum xnn_status xnn_create_elu_nc_qs8(
  3566. float alpha,
  3567. int8_t input_zero_point,
  3568. float input_scale,
  3569. int8_t output_zero_point,
  3570. float output_scale,
  3571. int8_t output_min,
  3572. int8_t output_max,
  3573. uint32_t flags,
  3574. xnn_operator_t* elu_op_out);
  3575. enum xnn_status xnn_reshape_elu_nc_qs8(
  3576. xnn_operator_t elu_op,
  3577. size_t batch_size,
  3578. size_t channels,
  3579. size_t input_stride,
  3580. size_t output_stride,
  3581. pthreadpool_t threadpool);
  3582. enum xnn_status xnn_setup_elu_nc_qs8(
  3583. xnn_operator_t elu_op,
  3584. const int8_t* input,
  3585. int8_t* output);
  3586. enum xnn_status xnn_create_floor_nc_f16(
  3587. uint32_t flags,
  3588. xnn_operator_t* floor_op_out);
  3589. enum xnn_status xnn_reshape_floor_nc_f16(
  3590. xnn_operator_t floor_op,
  3591. size_t batch_size,
  3592. size_t channels,
  3593. size_t input_stride,
  3594. size_t output_stride,
  3595. pthreadpool_t threadpool);
  3596. enum xnn_status xnn_setup_floor_nc_f16(
  3597. xnn_operator_t floor_op,
  3598. const void* input,
  3599. void* output);
  3600. enum xnn_status xnn_create_floor_nc_f32(
  3601. uint32_t flags,
  3602. xnn_operator_t* floor_op_out);
  3603. enum xnn_status xnn_reshape_floor_nc_f32(
  3604. xnn_operator_t floor_op,
  3605. size_t batch_size,
  3606. size_t channels,
  3607. size_t input_stride,
  3608. size_t output_stride,
  3609. pthreadpool_t threadpool);
  3610. enum xnn_status xnn_setup_floor_nc_f32(
  3611. xnn_operator_t floor_op,
  3612. const float* input,
  3613. float* output);
  3614. enum xnn_status xnn_run_floor_nc_f32(
  3615. size_t channels,
  3616. size_t input_stride,
  3617. size_t output_stride,
  3618. size_t batch_size,
  3619. const float* input,
  3620. float* output,
  3621. uint32_t flags,
  3622. pthreadpool_t threadpool);
  3623. enum xnn_status xnn_create_fully_connected_nc_f16(
  3624. size_t input_channels,
  3625. size_t output_channels,
  3626. size_t input_stride,
  3627. size_t output_stride,
  3628. const void* kernel,
  3629. const void* bias,
  3630. float output_min,
  3631. float output_max,
  3632. uint32_t flags,
  3633. xnn_code_cache_t code_cache,
  3634. xnn_weights_cache_t weights_cache,
  3635. xnn_operator_t* fully_connected_op_out);
  3636. enum xnn_status xnn_reshape_fully_connected_nc_f16(
  3637. xnn_operator_t fully_connected_op,
  3638. size_t batch_size,
  3639. pthreadpool_t threadpool);
  3640. enum xnn_status xnn_setup_fully_connected_nc_f16(
  3641. xnn_operator_t fully_connected_op,
  3642. const void* input,
  3643. void* output);
  3644. enum xnn_status xnn_create_fully_connected_nc_f32(
  3645. size_t input_channels,
  3646. size_t output_channels,
  3647. size_t input_stride,
  3648. size_t output_stride,
  3649. const float* kernel,
  3650. const float* bias,
  3651. float output_min,
  3652. float output_max,
  3653. uint32_t flags,
  3654. xnn_code_cache_t code_cache,
  3655. xnn_weights_cache_t weights_cache,
  3656. xnn_operator_t* fully_connected_op_out);
  3657. enum xnn_status xnn_reshape_fully_connected_nc_f32(
  3658. xnn_operator_t fully_connected_op,
  3659. size_t batch_size,
  3660. pthreadpool_t threadpool);
  3661. enum xnn_status xnn_setup_fully_connected_nc_f32(
  3662. xnn_operator_t fully_connected_op,
  3663. const float* input,
  3664. float* output);
  3665. enum xnn_status xnn_create_fully_connected_nc_f32_qc4w(
  3666. size_t input_channels,
  3667. size_t output_channels,
  3668. size_t input_stride,
  3669. size_t output_stride,
  3670. uint8_t kernel_zero_point,
  3671. const float* kernel_scale,
  3672. const uint8_t* kernel,
  3673. const float* bias,
  3674. float output_min,
  3675. float output_max,
  3676. uint32_t flags,
  3677. xnn_code_cache_t code_cache,
  3678. xnn_weights_cache_t weights_cache,
  3679. xnn_operator_t* fully_connected_op_out);
  3680. enum xnn_status xnn_reshape_fully_connected_nc_f32_qc4w(
  3681. xnn_operator_t fully_connected_op,
  3682. size_t batch_size,
  3683. pthreadpool_t threadpool);
  3684. enum xnn_status xnn_setup_fully_connected_nc_f32_qc4w(
  3685. xnn_operator_t fully_connected_op,
  3686. const float* input,
  3687. float* output);
  3688. enum xnn_status xnn_create_fully_connected_nc_f32_qc8w(
  3689. size_t input_channels,
  3690. size_t output_channels,
  3691. size_t input_stride,
  3692. size_t output_stride,
  3693. const float* kernel_scale,
  3694. const int8_t* kernel,
  3695. const float* bias,
  3696. float output_min,
  3697. float output_max,
  3698. uint32_t flags,
  3699. xnn_code_cache_t code_cache,
  3700. xnn_weights_cache_t weights_cache,
  3701. xnn_operator_t* fully_connected_op_out);
  3702. enum xnn_status xnn_reshape_fully_connected_nc_f32_qc8w(
  3703. xnn_operator_t fully_connected_op,
  3704. size_t batch_size,
  3705. pthreadpool_t threadpool);
  3706. enum xnn_status xnn_setup_fully_connected_nc_f32_qc8w(
  3707. xnn_operator_t fully_connected_op,
  3708. const float* input,
  3709. float* output);
  3710. enum xnn_status xnn_create_fully_connected_nc_qd8_f16_qc4w(
  3711. size_t input_channels,
  3712. size_t output_channels,
  3713. size_t input_stride,
  3714. size_t output_stride,
  3715. uint8_t kernel_zero_point,
  3716. const float* kernel_scale,
  3717. const void* kernel,
  3718. const float* bias,
  3719. float output_min,
  3720. float output_max,
  3721. uint32_t flags,
  3722. xnn_code_cache_t code_cache,
  3723. xnn_weights_cache_t weights_cache,
  3724. xnn_operator_t* fully_connected_op_out);
  3725. enum xnn_status xnn_setup_fully_connected_nc_qd8_f16_qc4w(
  3726. xnn_operator_t fully_connected_op,
  3727. const int8_t* input,
  3728. void* output,
  3729. const struct xnn_dynamic_quantization_params* quantization_params);
  3730. enum xnn_status xnn_reshape_fully_connected_nc_qd8_f16_qc4w(
  3731. xnn_operator_t fully_connected_op,
  3732. size_t batch_size,
  3733. pthreadpool_t threadpool);
  3734. enum xnn_status xnn_create_fully_connected_nc_qd8_f32_qc4w(
  3735. size_t input_channels,
  3736. size_t output_channels,
  3737. size_t input_stride,
  3738. size_t output_stride,
  3739. uint8_t kernel_zero_point,
  3740. const float* kernel_scale,
  3741. const void* kernel,
  3742. const float* bias,
  3743. float output_min,
  3744. float output_max,
  3745. uint32_t flags,
  3746. xnn_code_cache_t code_cache,
  3747. xnn_weights_cache_t weights_cache,
  3748. xnn_operator_t* fully_connected_op_out);
  3749. enum xnn_status xnn_setup_fully_connected_nc_qd8_f32_qc4w(
  3750. xnn_operator_t fully_connected_op,
  3751. const int8_t* input,
  3752. float* output,
  3753. const struct xnn_dynamic_quantization_params* quantization_params);
  3754. enum xnn_status xnn_reshape_fully_connected_nc_qd8_f32_qc4w(
  3755. xnn_operator_t fully_connected_op,
  3756. size_t batch_size,
  3757. pthreadpool_t threadpool);
  3758. enum xnn_status xnn_create_fully_connected_nc_qd8_f16_qc8w(
  3759. size_t input_channels,
  3760. size_t output_channels,
  3761. size_t input_stride,
  3762. size_t output_stride,
  3763. const float* kernel_scale,
  3764. const int8_t* kernel,
  3765. const float* bias,
  3766. float output_min,
  3767. float output_max,
  3768. uint32_t flags,
  3769. xnn_code_cache_t code_cache,
  3770. xnn_weights_cache_t weights_cache,
  3771. xnn_operator_t* fully_connected_op_out);
  3772. enum xnn_status xnn_setup_fully_connected_nc_qd8_f16_qc8w(
  3773. xnn_operator_t fully_connected_op,
  3774. const int8_t* input,
  3775. void* output,
  3776. const struct xnn_dynamic_quantization_params* quantization_params);
  3777. enum xnn_status xnn_reshape_fully_connected_nc_qd8_f16_qc8w(
  3778. xnn_operator_t fully_connected_op,
  3779. size_t batch_size,
  3780. pthreadpool_t threadpool);
  3781. enum xnn_status xnn_create_fully_connected_nc_qd8_f32_qc8w(
  3782. size_t input_channels,
  3783. size_t output_channels,
  3784. size_t input_stride,
  3785. size_t output_stride,
  3786. const float* kernel_scale,
  3787. const int8_t* kernel,
  3788. const float* bias,
  3789. float output_min,
  3790. float output_max,
  3791. uint32_t flags,
  3792. xnn_code_cache_t code_cache,
  3793. xnn_weights_cache_t weights_cache,
  3794. xnn_operator_t* fully_connected_op_out);
  3795. enum xnn_status xnn_setup_fully_connected_nc_qd8_f32_qc8w(
  3796. xnn_operator_t fully_connected_op,
  3797. const int8_t* input,
  3798. float* output,
  3799. const struct xnn_dynamic_quantization_params* quantization_params);
  3800. enum xnn_status xnn_reshape_fully_connected_nc_qd8_f32_qc8w(
  3801. xnn_operator_t fully_connected_op,
  3802. size_t batch_size,
  3803. pthreadpool_t threadpool);
  3804. enum xnn_status xnn_create_fully_connected_nc_qs8(
  3805. size_t input_channels,
  3806. size_t output_channels,
  3807. size_t input_stride,
  3808. size_t output_stride,
  3809. int8_t input_zero_point,
  3810. float input_scale,
  3811. float kernel_scale,
  3812. const int8_t* kernel,
  3813. const int32_t* bias,
  3814. int8_t output_zero_point,
  3815. float output_scale,
  3816. int8_t output_min,
  3817. int8_t output_max,
  3818. uint32_t flags,
  3819. xnn_code_cache_t code_cache,
  3820. xnn_weights_cache_t weights_cache,
  3821. xnn_operator_t* fully_connected_op_out);
  3822. enum xnn_status xnn_reshape_fully_connected_nc_qs8(
  3823. xnn_operator_t fully_connected_op,
  3824. size_t batch_size,
  3825. pthreadpool_t threadpool);
  3826. enum xnn_status xnn_setup_fully_connected_nc_qs8(
  3827. xnn_operator_t fully_connected_op,
  3828. const int8_t* input,
  3829. int8_t* output);
  3830. enum xnn_status xnn_create_fully_connected_nc_qs8_qc8w(
  3831. size_t input_channels,
  3832. size_t output_channels,
  3833. size_t input_stride,
  3834. size_t output_stride,
  3835. int8_t input_zero_point,
  3836. float input_scale,
  3837. const float* kernel_scale,
  3838. const int8_t* kernel,
  3839. const int32_t* bias,
  3840. int8_t output_zero_point,
  3841. float output_scale,
  3842. int8_t output_min,
  3843. int8_t output_max,
  3844. uint32_t flags,
  3845. xnn_code_cache_t code_cache,
  3846. xnn_weights_cache_t weights_cache,
  3847. xnn_operator_t* fully_connected_op_out);
  3848. enum xnn_status xnn_reshape_fully_connected_nc_qs8_qc8w(
  3849. xnn_operator_t fully_connected_op,
  3850. size_t batch_size,
  3851. pthreadpool_t threadpool);
  3852. enum xnn_status xnn_setup_fully_connected_nc_qs8_qc8w(
  3853. xnn_operator_t fully_connected_op,
  3854. const int8_t* input,
  3855. int8_t* output);
  3856. enum xnn_status xnn_create_fully_connected_nc_qu8(
  3857. size_t input_channels,
  3858. size_t output_channels,
  3859. size_t input_stride,
  3860. size_t output_stride,
  3861. uint8_t input_zero_point,
  3862. float input_scale,
  3863. uint8_t kernel_zero_point,
  3864. float kernel_scale,
  3865. const uint8_t* kernel,
  3866. const int32_t* bias,
  3867. uint8_t output_zero_point,
  3868. float output_scale,
  3869. uint8_t output_min,
  3870. uint8_t output_max,
  3871. uint32_t flags,
  3872. xnn_code_cache_t code_cache,
  3873. xnn_weights_cache_t weights_cache,
  3874. xnn_operator_t* fully_connected_op_out);
  3875. enum xnn_status xnn_reshape_fully_connected_nc_qu8(
  3876. xnn_operator_t fully_connected_op,
  3877. size_t batch_size,
  3878. pthreadpool_t threadpool);
  3879. enum xnn_status xnn_setup_fully_connected_nc_qu8(
  3880. xnn_operator_t fully_connected_op,
  3881. const uint8_t* input,
  3882. uint8_t* output);
  3883. enum xnn_status xnn_create_global_average_pooling_ncw_f16(
  3884. float output_min,
  3885. float output_max,
  3886. uint32_t flags,
  3887. xnn_operator_t* global_average_pooling_op_out);
  3888. enum xnn_status xnn_reshape_global_average_pooling_ncw_f16(
  3889. xnn_operator_t global_average_pooling_op,
  3890. size_t batch_size,
  3891. size_t width,
  3892. size_t channels,
  3893. pthreadpool_t threadpool);
  3894. enum xnn_status xnn_setup_global_average_pooling_ncw_f16(
  3895. xnn_operator_t global_average_pooling_op,
  3896. const void* input,
  3897. void* output);
  3898. enum xnn_status xnn_create_global_average_pooling_ncw_f32(
  3899. float output_min,
  3900. float output_max,
  3901. uint32_t flags,
  3902. xnn_operator_t* global_average_pooling_op_out);
  3903. enum xnn_status xnn_reshape_global_average_pooling_ncw_f32(
  3904. xnn_operator_t global_average_pooling_op,
  3905. size_t batch_size,
  3906. size_t width,
  3907. size_t channels,
  3908. pthreadpool_t threadpool);
  3909. enum xnn_status xnn_setup_global_average_pooling_ncw_f32(
  3910. xnn_operator_t global_average_pooling_op,
  3911. const float* input,
  3912. float* output);
  3913. enum xnn_status xnn_create_global_average_pooling_nwc_f16(
  3914. float output_min,
  3915. float output_max,
  3916. uint32_t flags,
  3917. xnn_operator_t* global_average_pooling_op_out);
  3918. enum xnn_status xnn_reshape_global_average_pooling_nwc_f16(
  3919. xnn_operator_t global_average_pooling_op,
  3920. size_t batch_size,
  3921. size_t width,
  3922. size_t channels,
  3923. size_t input_stride,
  3924. size_t output_stride,
  3925. size_t* workspace_size,
  3926. size_t* workspace_alignment,
  3927. pthreadpool_t threadpool);
  3928. enum xnn_status xnn_setup_global_average_pooling_nwc_f16(
  3929. xnn_operator_t global_average_pooling_op,
  3930. void* workspace,
  3931. const void* input,
  3932. void* output);
  3933. enum xnn_status xnn_create_global_average_pooling_nwc_f32(
  3934. float output_min,
  3935. float output_max,
  3936. uint32_t flags,
  3937. xnn_operator_t* global_average_pooling_op_out);
  3938. enum xnn_status xnn_reshape_global_average_pooling_nwc_f32(
  3939. xnn_operator_t global_average_pooling_op,
  3940. size_t batch_size,
  3941. size_t width,
  3942. size_t channels,
  3943. size_t input_stride,
  3944. size_t output_stride,
  3945. size_t* workspace_size,
  3946. size_t* workspace_alignment,
  3947. pthreadpool_t threadpool);
  3948. enum xnn_status xnn_setup_global_average_pooling_nwc_f32(
  3949. xnn_operator_t global_average_pooling_op,
  3950. void* workspace,
  3951. const float* input,
  3952. float* output);
  3953. enum xnn_status xnn_create_global_average_pooling_nwc_qs8(
  3954. int8_t input_zero_point,
  3955. float input_scale,
  3956. int8_t output_zero_point,
  3957. float output_scale,
  3958. int8_t output_min,
  3959. int8_t output_max,
  3960. uint32_t flags,
  3961. xnn_operator_t* global_average_pooling_op_out);
  3962. enum xnn_status xnn_reshape_global_average_pooling_nwc_qs8(
  3963. xnn_operator_t global_average_pooling_op,
  3964. size_t batch_size,
  3965. size_t width,
  3966. size_t channels,
  3967. size_t input_stride,
  3968. size_t output_stride,
  3969. size_t* workspace_size,
  3970. size_t* workspace_alignment,
  3971. pthreadpool_t threadpool);
  3972. enum xnn_status xnn_setup_global_average_pooling_nwc_qs8(
  3973. xnn_operator_t global_average_pooling_op,
  3974. void* workspace,
  3975. const int8_t* input,
  3976. int8_t* output);
  3977. enum xnn_status xnn_create_global_average_pooling_nwc_qu8(
  3978. uint8_t input_zero_point,
  3979. float input_scale,
  3980. uint8_t output_zero_point,
  3981. float output_scale,
  3982. uint8_t output_min,
  3983. uint8_t output_max,
  3984. uint32_t flags,
  3985. xnn_operator_t* global_average_pooling_op_out);
  3986. enum xnn_status xnn_reshape_global_average_pooling_nwc_qu8(
  3987. xnn_operator_t global_average_pooling_op,
  3988. size_t batch_size,
  3989. size_t width,
  3990. size_t channels,
  3991. size_t input_stride,
  3992. size_t output_stride,
  3993. size_t* workspace_size,
  3994. size_t* workspace_alignment,
  3995. pthreadpool_t threadpool);
  3996. enum xnn_status xnn_setup_global_average_pooling_nwc_qu8(
  3997. xnn_operator_t global_average_pooling_op,
  3998. void* workspace,
  3999. const uint8_t* input,
  4000. uint8_t* output);
  4001. enum xnn_status xnn_create_global_sum_pooling_nwc_f16(
  4002. float output_min,
  4003. float output_max,
  4004. uint32_t flags,
  4005. xnn_operator_t* global_sum_pooling_op_out);
  4006. enum xnn_status xnn_reshape_global_sum_pooling_nwc_f16(
  4007. xnn_operator_t global_sum_pooling_op,
  4008. size_t batch_size,
  4009. size_t width,
  4010. size_t channels,
  4011. size_t input_stride,
  4012. size_t output_stride,
  4013. size_t* workspace_size,
  4014. size_t* workspace_alignment,
  4015. pthreadpool_t threadpool);
  4016. enum xnn_status xnn_setup_global_sum_pooling_nwc_f16(
  4017. xnn_operator_t global_sum_pooling_op,
  4018. void* workspace,
  4019. const void* input,
  4020. void* output);
  4021. enum xnn_status xnn_create_global_sum_pooling_nwc_f32(
  4022. float output_min,
  4023. float output_max,
  4024. uint32_t flags,
  4025. xnn_operator_t* global_sum_pooling_op_out);
  4026. enum xnn_status xnn_reshape_global_sum_pooling_nwc_f32(
  4027. xnn_operator_t global_sum_pooling_op,
  4028. size_t batch_size,
  4029. size_t width,
  4030. size_t channels,
  4031. size_t input_stride,
  4032. size_t output_stride,
  4033. size_t* workspace_size,
  4034. size_t* workspace_alignment,
  4035. pthreadpool_t threadpool);
  4036. enum xnn_status xnn_setup_global_sum_pooling_nwc_f32(
  4037. xnn_operator_t global_sum_pooling_op,
  4038. void* workspace,
  4039. const float* input,
  4040. float* output);
  4041. enum xnn_status xnn_create_hardswish_nc_f16(
  4042. uint32_t flags,
  4043. xnn_operator_t* hardswish_op_out);
  4044. enum xnn_status xnn_reshape_hardswish_nc_f16(
  4045. xnn_operator_t hardswish_op,
  4046. size_t batch_size,
  4047. size_t channels,
  4048. size_t input_stride,
  4049. size_t output_stride,
  4050. pthreadpool_t threadpool);
  4051. enum xnn_status xnn_setup_hardswish_nc_f16(
  4052. xnn_operator_t hardswish_op,
  4053. const void* input,
  4054. void* output);
  4055. enum xnn_status xnn_create_hardswish_nc_f32(
  4056. uint32_t flags,
  4057. xnn_operator_t* hardswish_op_out);
  4058. enum xnn_status xnn_reshape_hardswish_nc_f32(
  4059. xnn_operator_t hardswish_op,
  4060. size_t batch_size,
  4061. size_t channels,
  4062. size_t input_stride,
  4063. size_t output_stride,
  4064. pthreadpool_t threadpool);
  4065. enum xnn_status xnn_setup_hardswish_nc_f32(
  4066. xnn_operator_t hardswish_op,
  4067. const float* input,
  4068. float* output);
  4069. enum xnn_status xnn_run_hardswish_nc_f32(
  4070. size_t channels,
  4071. size_t input_stride,
  4072. size_t output_stride,
  4073. size_t batch_size,
  4074. const float* input,
  4075. float* output,
  4076. uint32_t flags,
  4077. pthreadpool_t threadpool);
// Leaky ReLU operators (negative inputs scaled by `negative_slope`, per the
// operator name) on NC layout. Variants: f16 (void* data), f32, and the
// quantized qs8/qu8 forms, whose create functions additionally take the
// input/output quantization parameters (zero point + scale).
enum xnn_status xnn_create_leaky_relu_nc_f16(
  float negative_slope,
  uint32_t flags,
  xnn_operator_t* leaky_relu_op_out);
enum xnn_status xnn_reshape_leaky_relu_nc_f16(
  xnn_operator_t leaky_relu_op,
  size_t batch_size,
  size_t channels,
  size_t input_stride,
  size_t output_stride,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_leaky_relu_nc_f16(
  xnn_operator_t leaky_relu_op,
  const void* input,
  void* output);
enum xnn_status xnn_create_leaky_relu_nc_f32(
  float negative_slope,
  uint32_t flags,
  xnn_operator_t* leaky_relu_op_out);
enum xnn_status xnn_reshape_leaky_relu_nc_f32(
  xnn_operator_t leaky_relu_op,
  size_t batch_size,
  size_t channels,
  size_t input_stride,
  size_t output_stride,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_leaky_relu_nc_f32(
  xnn_operator_t leaky_relu_op,
  const float* input,
  float* output);
// One-shot f32 form combining create/reshape/setup/run in a single call.
enum xnn_status xnn_run_leaky_relu_nc_f32(
  size_t channels,
  size_t input_stride,
  size_t output_stride,
  size_t batch_size,
  const float* input,
  float* output,
  float negative_slope,
  uint32_t flags,
  pthreadpool_t threadpool);
// Quantized signed 8-bit variant: zero points and scales describe the
// affine quantization of the input and output tensors.
enum xnn_status xnn_create_leaky_relu_nc_qs8(
  float negative_slope,
  int8_t input_zero_point,
  float input_scale,
  int8_t output_zero_point,
  float output_scale,
  uint32_t flags,
  xnn_operator_t* leaky_relu_op_out);
enum xnn_status xnn_reshape_leaky_relu_nc_qs8(
  xnn_operator_t leaky_relu_op,
  size_t batch_size,
  size_t channels,
  size_t input_stride,
  size_t output_stride,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_leaky_relu_nc_qs8(
  xnn_operator_t leaky_relu_op,
  const int8_t* input,
  int8_t* output);
// Quantized unsigned 8-bit variant.
enum xnn_status xnn_create_leaky_relu_nc_qu8(
  float negative_slope,
  uint8_t input_zero_point,
  float input_scale,
  uint8_t output_zero_point,
  float output_scale,
  uint32_t flags,
  xnn_operator_t* leaky_relu_op_out);
enum xnn_status xnn_reshape_leaky_relu_nc_qu8(
  xnn_operator_t leaky_relu_op,
  size_t batch_size,
  size_t channels,
  size_t input_stride,
  size_t output_stride,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_leaky_relu_nc_qu8(
  xnn_operator_t leaky_relu_op,
  const uint8_t* input,
  uint8_t* output);
// 2D Max Pooling operators in NHWC layout. Create takes the static window
// description (padding, pooling window, stride, dilation) plus output
// clamping bounds [output_min, output_max]; reshape takes the dynamic tensor
// geometry and returns the computed output spatial size through
// output_height_out / output_width_out; setup binds data pointers.
// Variants: f16 (void* data), f32, s8, u8.
enum xnn_status xnn_create_max_pooling2d_nhwc_f16(
  uint32_t input_padding_top,
  uint32_t input_padding_right,
  uint32_t input_padding_bottom,
  uint32_t input_padding_left,
  uint32_t pooling_height,
  uint32_t pooling_width,
  uint32_t stride_height,
  uint32_t stride_width,
  uint32_t dilation_height,
  uint32_t dilation_width,
  float output_min,
  float output_max,
  uint32_t flags,
  xnn_operator_t* max_pooling_op_out);
enum xnn_status xnn_reshape_max_pooling2d_nhwc_f16(
  xnn_operator_t max_pooling_op,
  size_t batch_size,
  size_t input_height,
  size_t input_width,
  size_t channels,
  size_t input_pixel_stride,
  size_t output_pixel_stride,
  size_t* output_height_out,
  size_t* output_width_out,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_max_pooling2d_nhwc_f16(
  xnn_operator_t max_pooling_op,
  const void* input,
  void* output);
enum xnn_status xnn_create_max_pooling2d_nhwc_f32(
  uint32_t input_padding_top,
  uint32_t input_padding_right,
  uint32_t input_padding_bottom,
  uint32_t input_padding_left,
  uint32_t pooling_height,
  uint32_t pooling_width,
  uint32_t stride_height,
  uint32_t stride_width,
  uint32_t dilation_height,
  uint32_t dilation_width,
  float output_min,
  float output_max,
  uint32_t flags,
  xnn_operator_t* max_pooling_op_out);
enum xnn_status xnn_reshape_max_pooling2d_nhwc_f32(
  xnn_operator_t max_pooling_op,
  size_t batch_size,
  size_t input_height,
  size_t input_width,
  size_t channels,
  size_t input_pixel_stride,
  size_t output_pixel_stride,
  size_t* output_height_out,
  size_t* output_width_out,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_max_pooling2d_nhwc_f32(
  xnn_operator_t max_pooling_op,
  const float* input,
  float* output);
// Signed 8-bit variant: clamping bounds are int8_t (no quantization
// parameters are needed for max pooling).
enum xnn_status xnn_create_max_pooling2d_nhwc_s8(
  uint32_t input_padding_top,
  uint32_t input_padding_right,
  uint32_t input_padding_bottom,
  uint32_t input_padding_left,
  uint32_t pooling_height,
  uint32_t pooling_width,
  uint32_t stride_height,
  uint32_t stride_width,
  uint32_t dilation_height,
  uint32_t dilation_width,
  int8_t output_min,
  int8_t output_max,
  uint32_t flags,
  xnn_operator_t* max_pooling_op_out);
enum xnn_status xnn_reshape_max_pooling2d_nhwc_s8(
  xnn_operator_t max_pooling_op,
  size_t batch_size,
  size_t input_height,
  size_t input_width,
  size_t channels,
  size_t input_pixel_stride,
  size_t output_pixel_stride,
  size_t* output_height_out,
  size_t* output_width_out,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_max_pooling2d_nhwc_s8(
  xnn_operator_t max_pooling_op,
  const int8_t* input,
  int8_t* output);
// Unsigned 8-bit variant.
enum xnn_status xnn_create_max_pooling2d_nhwc_u8(
  uint32_t input_padding_top,
  uint32_t input_padding_right,
  uint32_t input_padding_bottom,
  uint32_t input_padding_left,
  uint32_t pooling_height,
  uint32_t pooling_width,
  uint32_t stride_height,
  uint32_t stride_width,
  uint32_t dilation_height,
  uint32_t dilation_width,
  uint8_t output_min,
  uint8_t output_max,
  uint32_t flags,
  xnn_operator_t* max_pooling_op_out);
enum xnn_status xnn_reshape_max_pooling2d_nhwc_u8(
  xnn_operator_t max_pooling_op,
  size_t batch_size,
  size_t input_height,
  size_t input_width,
  size_t channels,
  size_t input_pixel_stride,
  size_t output_pixel_stride,
  size_t* output_height_out,
  size_t* output_width_out,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_max_pooling2d_nhwc_u8(
  xnn_operator_t max_pooling_op,
  const uint8_t* input,
  uint8_t* output);
// Elementwise Maximum operators over two N-dimensional inputs. Reshape takes
// the two input shapes (rank + dimension arrays); the distinct ranks suggest
// broadcasting between the inputs — confirm against the library docs.
enum xnn_status xnn_create_maximum_nd_f16(
  uint32_t flags,
  xnn_operator_t* maximum_op_out);
enum xnn_status xnn_reshape_maximum_nd_f16(
  xnn_operator_t maximum_op,
  size_t num_input1_dims,
  const size_t* input1_shape,
  size_t num_input2_dims,
  const size_t* input2_shape,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_maximum_nd_f16(
  xnn_operator_t maximum_op,
  const void* input1,
  const void* input2,
  void* output);
enum xnn_status xnn_create_maximum_nd_f32(
  uint32_t flags,
  xnn_operator_t* maximum_op_out);
enum xnn_status xnn_reshape_maximum_nd_f32(
  xnn_operator_t maximum_op,
  size_t num_input1_dims,
  const size_t* input1_shape,
  size_t num_input2_dims,
  const size_t* input2_shape,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_maximum_nd_f32(
  xnn_operator_t maximum_op,
  const float* input1,
  const float* input2,
  float* output);
// One-shot f32 form.
enum xnn_status xnn_run_maximum_nd_f32(
  size_t num_input1_dims,
  const size_t* input1_shape,
  size_t num_input2_dims,
  const size_t* input2_shape,
  const float* input1,
  const float* input2,
  float* output,
  uint32_t flags,
  pthreadpool_t threadpool);
// Mean-reduction operators over a set of axes of an N-dimensional input.
// Reshape reports the scratch requirements via workspace_size /
// workspace_alignment; the caller allocates and passes the buffer to setup.
enum xnn_status xnn_create_mean_nd_f16(
  uint32_t flags,
  xnn_operator_t* mean_op_out);
enum xnn_status xnn_reshape_mean_nd_f16(
  xnn_operator_t mean_op,
  size_t num_reduction_axes,
  const size_t* reduction_axes,
  size_t num_input_dims,
  const size_t* input_shape,
  size_t* workspace_size,
  size_t* workspace_alignment,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_mean_nd_f16(
  xnn_operator_t mean_op,
  void* workspace,
  const void* input,
  void* output);
enum xnn_status xnn_create_mean_nd_f32(
  uint32_t flags,
  xnn_operator_t* mean_op_out);
enum xnn_status xnn_reshape_mean_nd_f32(
  xnn_operator_t mean_op,
  size_t num_reduction_axes,
  const size_t* reduction_axes,
  size_t num_input_dims,
  const size_t* input_shape,
  size_t* workspace_size,
  size_t* workspace_alignment,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_mean_nd_f32(
  xnn_operator_t mean_op,
  void* workspace,
  const float* input,
  float* output);
// Elementwise Minimum operators over two N-dimensional inputs; signatures
// mirror the Maximum family above.
enum xnn_status xnn_create_minimum_nd_f16(
  uint32_t flags,
  xnn_operator_t* minimum_op_out);
enum xnn_status xnn_reshape_minimum_nd_f16(
  xnn_operator_t minimum_op,
  size_t num_input1_dims,
  const size_t* input1_shape,
  size_t num_input2_dims,
  const size_t* input2_shape,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_minimum_nd_f16(
  xnn_operator_t minimum_op,
  const void* input1,
  const void* input2,
  void* output);
enum xnn_status xnn_create_minimum_nd_f32(
  uint32_t flags,
  xnn_operator_t* minimum_op_out);
enum xnn_status xnn_reshape_minimum_nd_f32(
  xnn_operator_t minimum_op,
  size_t num_input1_dims,
  const size_t* input1_shape,
  size_t num_input2_dims,
  const size_t* input2_shape,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_minimum_nd_f32(
  xnn_operator_t minimum_op,
  const float* input1,
  const float* input2,
  float* output);
// One-shot f32 form.
enum xnn_status xnn_run_minimum_nd_f32(
  size_t num_input1_dims,
  const size_t* input1_shape,
  size_t num_input2_dims,
  const size_t* input2_shape,
  const float* input1,
  const float* input2,
  float* output,
  uint32_t flags,
  pthreadpool_t threadpool);
// Elementwise Multiply operators over two N-dimensional inputs.
// Float variants clamp results to [output_min, output_max]; quantized
// variants take per-tensor affine quantization parameters (zero point +
// scale) for both inputs and the output, plus quantized clamping bounds.
enum xnn_status xnn_create_multiply_nd_f16(
  float output_min,
  float output_max,
  uint32_t flags,
  xnn_operator_t* multiply_op_out);
enum xnn_status xnn_reshape_multiply_nd_f16(
  xnn_operator_t multiply_op,
  size_t num_input1_dims,
  const size_t* input1_shape,
  size_t num_input2_dims,
  const size_t* input2_shape,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_multiply_nd_f16(
  xnn_operator_t multiply_op,
  const void* input1,
  const void* input2,
  void* output);
enum xnn_status xnn_create_multiply_nd_f32(
  float output_min,
  float output_max,
  uint32_t flags,
  xnn_operator_t* multiply_op_out);
enum xnn_status xnn_reshape_multiply_nd_f32(
  xnn_operator_t multiply_op,
  size_t num_input1_dims,
  const size_t* input1_shape,
  size_t num_input2_dims,
  const size_t* input2_shape,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_multiply_nd_f32(
  xnn_operator_t multiply_op,
  const float* input1,
  const float* input2,
  float* output);
// One-shot f32 form.
enum xnn_status xnn_run_multiply_nd_f32(
  size_t num_input1_dims,
  const size_t* input1_shape,
  size_t num_input2_dims,
  const size_t* input2_shape,
  const float* input1,
  const float* input2,
  float* output,
  float output_min,
  float output_max,
  uint32_t flags,
  pthreadpool_t threadpool);
// Quantized signed 8-bit variant.
enum xnn_status xnn_create_multiply_nd_qs8(
  int8_t input1_zero_point,
  float input1_scale,
  int8_t input2_zero_point,
  float input2_scale,
  int8_t output_zero_point,
  float output_scale,
  int8_t output_min,
  int8_t output_max,
  uint32_t flags,
  xnn_operator_t* multiply_op_out);
enum xnn_status xnn_reshape_multiply_nd_qs8(
  xnn_operator_t multiply_op,
  size_t num_input1_dims,
  const size_t* input1_shape,
  size_t num_input2_dims,
  const size_t* input2_shape,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_multiply_nd_qs8(
  xnn_operator_t multiply_op,
  const int8_t* input1,
  const int8_t* input2,
  int8_t* output);
// One-shot qs8 form: shape, quantization, and data arguments in one call.
enum xnn_status xnn_run_multiply_nd_qs8(
  size_t num_input1_dims,
  const size_t* input1_shape,
  int8_t input1_zero_point,
  float input1_scale,
  size_t num_input2_dims,
  const size_t* input2_shape,
  int8_t input2_zero_point,
  float input2_scale,
  const int8_t* input1,
  const int8_t* input2,
  int8_t* output,
  int8_t output_zero_point,
  float output_scale,
  int8_t output_min,
  int8_t output_max,
  uint32_t flags,
  pthreadpool_t threadpool);
// Quantized unsigned 8-bit variant.
enum xnn_status xnn_create_multiply_nd_qu8(
  uint8_t input1_zero_point,
  float input1_scale,
  uint8_t input2_zero_point,
  float input2_scale,
  uint8_t output_zero_point,
  float output_scale,
  uint8_t output_min,
  uint8_t output_max,
  uint32_t flags,
  xnn_operator_t* multiply_op_out);
enum xnn_status xnn_reshape_multiply_nd_qu8(
  xnn_operator_t multiply_op,
  size_t num_input1_dims,
  const size_t* input1_shape,
  size_t num_input2_dims,
  const size_t* input2_shape,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_multiply_nd_qu8(
  xnn_operator_t multiply_op,
  const uint8_t* input1,
  const uint8_t* input2,
  uint8_t* output);
// One-shot qu8 form.
enum xnn_status xnn_run_multiply_nd_qu8(
  size_t num_input1_dims,
  const size_t* input1_shape,
  uint8_t input1_zero_point,
  float input1_scale,
  size_t num_input2_dims,
  const size_t* input2_shape,
  uint8_t input2_zero_point,
  float input2_scale,
  const uint8_t* input1,
  const uint8_t* input2,
  uint8_t* output,
  uint8_t output_zero_point,
  float output_scale,
  uint8_t output_min,
  uint8_t output_max,
  uint32_t flags,
  pthreadpool_t threadpool);
// Negate operators (elementwise arithmetic negation, per the operator name)
// on NC layout tensors.
enum xnn_status xnn_create_negate_nc_f16(
  uint32_t flags,
  xnn_operator_t* negate_op_out);
enum xnn_status xnn_reshape_negate_nc_f16(
  xnn_operator_t negate_op,
  size_t batch_size,
  size_t channels,
  size_t input_stride,
  size_t output_stride,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_negate_nc_f16(
  xnn_operator_t negate_op,
  const void* input,
  void* output);
enum xnn_status xnn_create_negate_nc_f32(
  uint32_t flags,
  xnn_operator_t* negate_op_out);
enum xnn_status xnn_reshape_negate_nc_f32(
  xnn_operator_t negate_op,
  size_t batch_size,
  size_t channels,
  size_t input_stride,
  size_t output_stride,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_negate_nc_f32(
  xnn_operator_t negate_op,
  const float* input,
  float* output);
// One-shot f32 form.
enum xnn_status xnn_run_negate_nc_f32(
  size_t channels,
  size_t input_stride,
  size_t output_stride,
  size_t batch_size,
  const float* input,
  float* output,
  uint32_t flags,
  pthreadpool_t threadpool);
// PReLU operators on NC layout: `negative_slope` points to per-channel slope
// weights consumed at create time (`channels` elements — the exact packing
// is internal to the library). Unlike most NC operators, geometry except
// batch size is fixed at create time, and create accepts optional code and
// weights caches.
enum xnn_status xnn_create_prelu_nc_f16(
  size_t channels,
  size_t input_stride,
  size_t output_stride,
  const void* negative_slope,
  uint32_t flags,
  xnn_code_cache_t code_cache,
  xnn_weights_cache_t weights_cache,
  xnn_operator_t* prelu_op_out);
enum xnn_status xnn_reshape_prelu_nc_f16(
  xnn_operator_t prelu_op,
  size_t batch_size,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_prelu_nc_f16(
  xnn_operator_t prelu_op,
  const void* input,
  void* output);
enum xnn_status xnn_create_prelu_nc_f32(
  size_t channels,
  size_t input_stride,
  size_t output_stride,
  const float* negative_slope,
  uint32_t flags,
  xnn_code_cache_t code_cache,
  xnn_weights_cache_t weights_cache,
  xnn_operator_t* prelu_op_out);
enum xnn_status xnn_reshape_prelu_nc_f32(
  xnn_operator_t prelu_op,
  size_t batch_size,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_prelu_nc_f32(
  xnn_operator_t prelu_op,
  const float* input,
  float* output);
// Bilinear 2D resize operators in NCHW layout. The target output size is
// fixed at create time; reshape supplies the input geometry. No workspace
// is required in the NCHW variants (contrast with the NHWC family below).
enum xnn_status xnn_create_resize_bilinear2d_nchw_f32(
  size_t output_height,
  size_t output_width,
  uint32_t flags,
  xnn_operator_t* resize_op_out);
enum xnn_status xnn_reshape_resize_bilinear2d_nchw_f32(
  xnn_operator_t resize_op,
  size_t batch_size,
  size_t input_height,
  size_t input_width,
  size_t channels,
  size_t input_pixel_stride,
  size_t output_pixel_stride,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_resize_bilinear2d_nchw_f32(
  xnn_operator_t resize_op,
  const float* input,
  float* output);
enum xnn_status xnn_create_resize_bilinear2d_nchw_f16(
  size_t output_height,
  size_t output_width,
  uint32_t flags,
  xnn_operator_t* resize_op_out);
enum xnn_status xnn_reshape_resize_bilinear2d_nchw_f16(
  xnn_operator_t resize_op,
  size_t batch_size,
  size_t input_height,
  size_t input_width,
  size_t channels,
  size_t input_pixel_stride,
  size_t output_pixel_stride,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_resize_bilinear2d_nchw_f16(
  xnn_operator_t resize_op,
  const void* input,
  void* output);
// Bilinear 2D resize operators in NHWC layout (f16/f32). Reshape reports
// scratch requirements via workspace_size / workspace_alignment; the caller
// allocates the buffer and passes it to setup.
enum xnn_status xnn_create_resize_bilinear2d_nhwc_f16(
  size_t output_height,
  size_t output_width,
  uint32_t flags,
  xnn_operator_t* resize_op_out);
enum xnn_status xnn_reshape_resize_bilinear2d_nhwc_f16(
  xnn_operator_t resize_op,
  size_t batch_size,
  size_t input_height,
  size_t input_width,
  size_t channels,
  size_t input_pixel_stride,
  size_t output_pixel_stride,
  size_t* workspace_size,
  size_t* workspace_alignment,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_resize_bilinear2d_nhwc_f16(
  xnn_operator_t resize_op,
  void* workspace,
  const void* input,
  void* output);
enum xnn_status xnn_create_resize_bilinear2d_nhwc_f32(
  size_t output_height,
  size_t output_width,
  uint32_t flags,
  xnn_operator_t* resize_op_out);
enum xnn_status xnn_reshape_resize_bilinear2d_nhwc_f32(
  xnn_operator_t resize_op,
  size_t batch_size,
  size_t input_height,
  size_t input_width,
  size_t channels,
  size_t input_pixel_stride,
  size_t output_pixel_stride,
  size_t* workspace_size,
  size_t* workspace_alignment,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_resize_bilinear2d_nhwc_f32(
  xnn_operator_t resize_op,
  void* workspace,
  const float* input,
  float* output);
  4667. enum xnn_status xnn_create_resize_bilinear2d_nhwc_s8(
  4668. size_t output_height,
  4669. size_t output_width,
  4670. uint32_t flags,
  4671. xnn_operator_t* resize_op_out);
  4672. enum xnn_status xnn_reshape_resize_bilinear2d_nhwc_s8(
  4673. xnn_operator_t resize_op,
  4674. size_t batch_size,
  4675. size_t input_height,
  4676. size_t input_width,
  4677. size_t channels,
  4678. size_t input_pixel_stride,
  4679. size_t output_pixel_stride,
  4680. size_t* workspace_size,
  4681. size_t* workspace,
  4682. pthreadpool_t threadpool);
  4683. enum xnn_status xnn_setup_resize_bilinear2d_nhwc_s8(
  4684. xnn_operator_t resize_op,
  4685. void* workspace,
  4686. const int8_t* input,
  4687. int8_t* output);
// Bilinear 2D resize operator in NHWC layout, unsigned 8-bit; same
// create/reshape(workspace query)/setup protocol as the f16/f32 variants.
enum xnn_status xnn_create_resize_bilinear2d_nhwc_u8(
  size_t output_height,
  size_t output_width,
  uint32_t flags,
  xnn_operator_t* resize_op_out);
enum xnn_status xnn_reshape_resize_bilinear2d_nhwc_u8(
  xnn_operator_t resize_op,
  size_t batch_size,
  size_t input_height,
  size_t input_width,
  size_t channels,
  size_t input_pixel_stride,
  size_t output_pixel_stride,
  size_t* workspace_size,
  size_t* workspace_alignment,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_resize_bilinear2d_nhwc_u8(
  xnn_operator_t resize_op,
  void* workspace,
  const uint8_t* input,
  uint8_t* output);
// RoPE (Rotary Position Embedding, per the operator name) operators in NTHC
// (batch, tokens, heads, channels) layout. `max_tokens` caps the sequence
// length at create time; setup additionally takes a `weights` tensor
// (presumably the precomputed rotation coefficients — confirm in library
// docs).
enum xnn_status xnn_create_rope_nthc_f16(
  size_t max_tokens,
  uint32_t flags,
  xnn_operator_t* rope_op_out);
enum xnn_status xnn_reshape_rope_nthc_f16(
  xnn_operator_t rope_op,
  size_t batch_size,
  size_t tokens,
  size_t heads,
  size_t channels,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_rope_nthc_f16(
  xnn_operator_t rope_op,
  const void* input,
  const void* weights,
  void* output);
enum xnn_status xnn_create_rope_nthc_f32(
  size_t max_tokens,
  uint32_t flags,
  xnn_operator_t* rope_op_out);
enum xnn_status xnn_reshape_rope_nthc_f32(
  xnn_operator_t rope_op,
  size_t batch_size,
  size_t tokens,
  size_t heads,
  size_t channels,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_rope_nthc_f32(
  xnn_operator_t rope_op,
  const float* input,
  const float* weights,
  float* output);
// Scaled Dot-Product Attention in NHTC layout, half precision.
// N: batch size
// H: number of heads
// T: tokens (sequence length)
// C: channels (head dimension)
enum xnn_status xnn_create_scaled_dot_product_attention_nhtc_f16(
  enum xnn_attention_logits_cap_type cap_type,
  const void* cap_params,
  uint32_t flags,
  xnn_operator_t* attention_op_out);
enum xnn_status xnn_reshape_scaled_dot_product_attention_nhtc_f16(
  xnn_operator_t attention_op,
  size_t batch_size,
  size_t query_heads,
  // Number of tokens in query.
  size_t query_tokens,
  size_t key_value_heads,
  // Number of tokens in key/value. For self-attention, this is same as tokens.
  size_t key_value_tokens,
  size_t query_key_channels,
  size_t value_channels,
  size_t* workspace_size,
  size_t* workspace_alignment,
  pthreadpool_t threadpool);
// Query is of dimension [batch_size, query_heads, query_tokens, query_key_channels].
// Key and value are of dimension [batch_size, key_value_heads, key_value_tokens, query_key_channels].
// Scale is of dimension [query_key_channels].
// Mask is of dimension [query_tokens, key_value_tokens].
// Output is of dimension [batch_size, query_heads, query_tokens, value_channels].
enum xnn_status xnn_setup_scaled_dot_product_attention_nhtc_f16(
  xnn_operator_t attention_op,
  void* workspace,
  const void* query,
  const void* key,
  const void* value,
  const void* scale,
  const void* mask,
  void* output);
// Scaled Dot-Product Attention in NHTC layout, single precision.
// N: batch size
// H: number of heads
// T: tokens (sequence length)
// C: channels (head dimension)
enum xnn_status xnn_create_scaled_dot_product_attention_nhtc_f32(
  enum xnn_attention_logits_cap_type cap_type,
  const void* cap_params,
  uint32_t flags,
  xnn_operator_t* attention_op_out);
enum xnn_status xnn_reshape_scaled_dot_product_attention_nhtc_f32(
  xnn_operator_t attention_op,
  size_t batch_size,
  size_t query_heads,
  // Number of tokens in query.
  size_t query_tokens,
  size_t key_value_heads,
  // Number of tokens in key/value. For self-attention, this is same as tokens.
  size_t key_value_tokens,
  size_t query_key_channels,
  size_t value_channels,
  // Out-parameters for caller-allocated scratch space, passed to setup.
  size_t* workspace_size,
  size_t* workspace_alignment,
  pthreadpool_t threadpool);
// Query is of dimension [batch_size, query_heads, query_tokens, query_key_channels].
// Key and value are of dimension [batch_size, key_value_heads, key_value_tokens, query_key_channels].
// Scale is of dimension [query_key_channels].
// Mask is of dimension [query_tokens, key_value_tokens].
// Output is of dimension [batch_size, query_heads, query_tokens, value_channels].
enum xnn_status xnn_setup_scaled_dot_product_attention_nhtc_f32(
  xnn_operator_t attention_op,
  void* workspace,
  const float* query,
  const float* key,
  const float* value,
  const float* scale,
  const float* mask,
  float* output);
// Sigmoid activation operators on NC layout. Quantized variants take
// input/output quantization parameters (zero point + scale) and quantized
// output clamping bounds at create time.
enum xnn_status xnn_create_sigmoid_nc_f16(
  uint32_t flags,
  xnn_operator_t* sigmoid_op_out);
enum xnn_status xnn_reshape_sigmoid_nc_f16(
  xnn_operator_t sigmoid_op,
  size_t batch_size,
  size_t channels,
  size_t input_stride,
  size_t output_stride,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_sigmoid_nc_f16(
  xnn_operator_t sigmoid_op,
  const void* input,
  void* output);
enum xnn_status xnn_create_sigmoid_nc_f32(
  uint32_t flags,
  xnn_operator_t* sigmoid_op_out);
enum xnn_status xnn_reshape_sigmoid_nc_f32(
  xnn_operator_t sigmoid_op,
  size_t batch_size,
  size_t channels,
  size_t input_stride,
  size_t output_stride,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_sigmoid_nc_f32(
  xnn_operator_t sigmoid_op,
  const float* input,
  float* output);
// One-shot f32 form.
enum xnn_status xnn_run_sigmoid_nc_f32(
  size_t channels,
  size_t input_stride,
  size_t output_stride,
  size_t batch_size,
  const float* input,
  float* output,
  uint32_t flags,
  pthreadpool_t threadpool);
// Quantized signed 8-bit variant.
enum xnn_status xnn_create_sigmoid_nc_qs8(
  int8_t input_zero_point,
  float input_scale,
  int8_t output_zero_point,
  float output_scale,
  int8_t output_min,
  int8_t output_max,
  uint32_t flags,
  xnn_operator_t* sigmoid_op_out);
enum xnn_status xnn_reshape_sigmoid_nc_qs8(
  xnn_operator_t sigmoid_op,
  size_t batch_size,
  size_t channels,
  size_t input_stride,
  size_t output_stride,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_sigmoid_nc_qs8(
  xnn_operator_t sigmoid_op,
  const int8_t* input,
  int8_t* output);
// Quantized unsigned 8-bit variant.
enum xnn_status xnn_create_sigmoid_nc_qu8(
  uint8_t input_zero_point,
  float input_scale,
  uint8_t output_zero_point,
  float output_scale,
  uint8_t output_min,
  uint8_t output_max,
  uint32_t flags,
  xnn_operator_t* sigmoid_op_out);
enum xnn_status xnn_reshape_sigmoid_nc_qu8(
  xnn_operator_t sigmoid_op,
  size_t batch_size,
  size_t channels,
  size_t input_stride,
  size_t output_stride,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_sigmoid_nc_qu8(
  xnn_operator_t sigmoid_op,
  const uint8_t* input,
  uint8_t* output);
// N-dimensional Slice operators for 16-bit (x16) and 32-bit (x32) elements
// (type-agnostic: data moves through void pointers). Reshape takes the input
// shape plus per-dimension slice offsets and sizes.
enum xnn_status xnn_create_slice_nd_x16(
  uint32_t flags,
  xnn_operator_t* slice_op_out);
enum xnn_status xnn_reshape_slice_nd_x16(
  xnn_operator_t slice_op,
  size_t num_dims,
  const size_t* input_shape,
  const size_t* offsets,
  const size_t* sizes,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_slice_nd_x16(
  xnn_operator_t slice_op,
  const void* input,
  void* output);
enum xnn_status xnn_create_slice_nd_x32(
  uint32_t flags,
  xnn_operator_t* slice_op_out);
enum xnn_status xnn_reshape_slice_nd_x32(
  xnn_operator_t slice_op,
  size_t num_dims,
  const size_t* input_shape,
  const size_t* offsets,
  const size_t* sizes,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_slice_nd_x32(
  xnn_operator_t slice_op,
  const void* input,
  void* output);
// One-shot x32 form.
enum xnn_status xnn_run_slice_nd_x32(
  size_t num_dims,
  const size_t* input_shape,
  const size_t* offsets,
  const size_t* sizes,
  const void* input,
  void* output,
  uint32_t flags,
  pthreadpool_t threadpool);
// SoftMax operators on NC layout (reduction across the channel dimension is
// implied by the NC naming — note the reshape parameter order here is
// channels/strides first, batch_size last, unlike most NC operators).
// The qu8 variant takes the input scale and output quantization parameters
// at create time.
enum xnn_status xnn_create_softmax_nc_f16(
  uint32_t flags,
  xnn_operator_t* softmax_op_out);
enum xnn_status xnn_reshape_softmax_nc_f16(
  xnn_operator_t softmax_op,
  size_t channels,
  size_t input_stride,
  size_t output_stride,
  size_t batch_size,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_softmax_nc_f16(
  xnn_operator_t softmax_op,
  const void* input,
  void* output);
enum xnn_status xnn_create_softmax_nc_f32(
  uint32_t flags,
  xnn_operator_t* softmax_op_out);
enum xnn_status xnn_reshape_softmax_nc_f32(
  xnn_operator_t softmax_op,
  size_t channels,
  size_t input_stride,
  size_t output_stride,
  size_t batch_size,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_softmax_nc_f32(
  xnn_operator_t softmax_op,
  const float* input,
  float* output);
enum xnn_status xnn_create_softmax_nc_qu8(
  float input_scale,
  uint8_t output_zero_point,
  float output_scale,
  uint32_t flags,
  xnn_operator_t* softmax_op_out);
enum xnn_status xnn_reshape_softmax_nc_qu8(
  xnn_operator_t softmax_op,
  size_t channels,
  size_t input_stride,
  size_t output_stride,
  size_t batch_size,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_softmax_nc_qu8(
  xnn_operator_t softmax_op,
  const uint8_t* input,
  uint8_t* output);
// Space-to-Depth operators in NHWC layout for 16-bit (x16) and 32-bit (x32)
// elements. `block_size` is fixed at create time; reshape reports the
// resulting output height/width/channels through its out-parameters.
enum xnn_status xnn_create_space_to_depth_nhwc_x16(
  uint32_t block_size,
  uint32_t flags,
  xnn_operator_t* space_to_depth_op_out);
enum xnn_status xnn_reshape_space_to_depth_nhwc_x16(
  xnn_operator_t space_to_depth_op,
  size_t batch_size,
  size_t input_height,
  size_t input_width,
  size_t input_channels,
  size_t* output_height_out,
  size_t* output_width_out,
  size_t* output_channels_out,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_space_to_depth_nhwc_x16(
  xnn_operator_t space_to_depth_op,
  const void* input,
  void* output);
enum xnn_status xnn_create_space_to_depth_nhwc_x32(
  uint32_t block_size,
  uint32_t flags,
  xnn_operator_t* space_to_depth_op_out);
enum xnn_status xnn_reshape_space_to_depth_nhwc_x32(
  xnn_operator_t space_to_depth_op,
  size_t batch_size,
  size_t input_height,
  size_t input_width,
  size_t input_channels,
  size_t* output_height_out,
  size_t* output_width_out,
  size_t* output_channels_out,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_space_to_depth_nhwc_x32(
  xnn_operator_t space_to_depth_op,
  const void* input,
  void* output);
// Square operators (elementwise squaring, per the operator name) on NC
// layout tensors.
enum xnn_status xnn_create_square_nc_f16(
  uint32_t flags,
  xnn_operator_t* square_op_out);
enum xnn_status xnn_reshape_square_nc_f16(
  xnn_operator_t square_op,
  size_t batch_size,
  size_t channels,
  size_t input_stride,
  size_t output_stride,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_square_nc_f16(
  xnn_operator_t square_op,
  const void* input,
  void* output);
enum xnn_status xnn_create_square_nc_f32(
  uint32_t flags,
  xnn_operator_t* square_op_out);
enum xnn_status xnn_reshape_square_nc_f32(
  xnn_operator_t square_op,
  size_t batch_size,
  size_t channels,
  size_t input_stride,
  size_t output_stride,
  pthreadpool_t threadpool);
enum xnn_status xnn_setup_square_nc_f32(
  xnn_operator_t square_op,
  const float* input,
  float* output);
// One-shot f32 form.
enum xnn_status xnn_run_square_nc_f32(
  size_t channels,
  size_t input_stride,
  size_t output_stride,
  size_t batch_size,
  const float* input,
  float* output,
  uint32_t flags,
  pthreadpool_t threadpool);
  5046. enum xnn_status xnn_create_square_root_nc_f16(
  5047. uint32_t flags,
  5048. xnn_operator_t* sqrt_op_out);
  5049. enum xnn_status xnn_reshape_square_root_nc_f16(
  5050. xnn_operator_t sqrt_op,
  5051. size_t batch_size,
  5052. size_t channels,
  5053. size_t input_stride,
  5054. size_t output_stride,
  5055. pthreadpool_t threadpool);
  5056. enum xnn_status xnn_setup_square_root_nc_f16(
  5057. xnn_operator_t sqrt_op,
  5058. const void* input,
  5059. void* output);
  5060. enum xnn_status xnn_create_square_root_nc_f32(
  5061. uint32_t flags,
  5062. xnn_operator_t* sqrt_op_out);
  5063. enum xnn_status xnn_reshape_square_root_nc_f32(
  5064. xnn_operator_t sqrt_op,
  5065. size_t batch_size,
  5066. size_t channels,
  5067. size_t input_stride,
  5068. size_t output_stride,
  5069. pthreadpool_t threadpool);
  5070. enum xnn_status xnn_setup_square_root_nc_f32(
  5071. xnn_operator_t sqrt_op,
  5072. const float* input,
  5073. float* output);
  5074. enum xnn_status xnn_run_square_root_nc_f32(
  5075. size_t channels,
  5076. size_t input_stride,
  5077. size_t output_stride,
  5078. size_t batch_size,
  5079. const float* input,
  5080. float* output,
  5081. uint32_t flags,
  5082. pthreadpool_t threadpool);
  5083. enum xnn_status xnn_create_reciprocal_square_root_nc_f32(
  5084. uint32_t flags, xnn_operator_t* sqrt_op_out);
  5085. enum xnn_status xnn_reshape_reciprocal_square_root_nc_f32(
  5086. xnn_operator_t sqrt_op, size_t batch_size, size_t channels,
  5087. size_t input_stride, size_t output_stride, pthreadpool_t threadpool);
  5088. enum xnn_status xnn_setup_reciprocal_square_root_nc_f32(xnn_operator_t sqrt_op,
  5089. const float* input,
  5090. float* output);
  5091. enum xnn_status xnn_run_reciprocal_square_root_nc_f32(
  5092. size_t channels, size_t input_stride, size_t output_stride,
  5093. size_t batch_size, const float* input, float* output, uint32_t flags,
  5094. pthreadpool_t threadpool);
  5095. enum xnn_status xnn_create_squared_difference_nd_f16(
  5096. uint32_t flags,
  5097. xnn_operator_t* squared_difference_op_out);
  5098. enum xnn_status xnn_reshape_squared_difference_nd_f16(
  5099. xnn_operator_t squared_difference_op,
  5100. size_t num_input1_dims,
  5101. const size_t* input1_shape,
  5102. size_t num_input2_dims,
  5103. const size_t* input2_shape,
  5104. pthreadpool_t threadpool);
  5105. enum xnn_status xnn_setup_squared_difference_nd_f16(
  5106. xnn_operator_t squared_difference_op,
  5107. const void* input1,
  5108. const void* input2,
  5109. void* output);
  5110. enum xnn_status xnn_create_squared_difference_nd_f32(
  5111. uint32_t flags,
  5112. xnn_operator_t* squared_difference_op_out);
  5113. enum xnn_status xnn_reshape_squared_difference_nd_f32(
  5114. xnn_operator_t squared_difference_op,
  5115. size_t num_input1_dims,
  5116. const size_t* input1_shape,
  5117. size_t num_input2_dims,
  5118. const size_t* input2_shape,
  5119. pthreadpool_t threadpool);
  5120. enum xnn_status xnn_setup_squared_difference_nd_f32(
  5121. xnn_operator_t squared_difference_op,
  5122. const float* input1,
  5123. const float* input2,
  5124. float* output);
  5125. enum xnn_status xnn_run_squared_difference_nd_f32(
  5126. size_t num_input1_dims,
  5127. const size_t* input1_shape,
  5128. size_t num_input2_dims,
  5129. const size_t* input2_shape,
  5130. const float* input1,
  5131. const float* input2,
  5132. float* output,
  5133. uint32_t flags,
  5134. pthreadpool_t threadpool);
  5135. enum xnn_status xnn_create_subtract_nd_f16(
  5136. float output_min,
  5137. float output_max,
  5138. uint32_t flags,
  5139. xnn_operator_t* subtract_op_out);
  5140. enum xnn_status xnn_reshape_subtract_nd_f16(
  5141. xnn_operator_t subtract_op,
  5142. size_t num_input1_dims,
  5143. const size_t* input1_shape,
  5144. size_t num_input2_dims,
  5145. const size_t* input2_shape,
  5146. pthreadpool_t threadpool);
  5147. enum xnn_status xnn_setup_subtract_nd_f16(
  5148. xnn_operator_t subtract_op,
  5149. const void* input1,
  5150. const void* input2,
  5151. void* output);
  5152. enum xnn_status xnn_create_subtract_nd_f32(
  5153. float output_min,
  5154. float output_max,
  5155. uint32_t flags,
  5156. xnn_operator_t* subtract_op_out);
  5157. enum xnn_status xnn_reshape_subtract_nd_f32(
  5158. xnn_operator_t subtract_op,
  5159. size_t num_input1_dims,
  5160. const size_t* input1_shape,
  5161. size_t num_input2_dims,
  5162. const size_t* input2_shape,
  5163. pthreadpool_t threadpool);
  5164. enum xnn_status xnn_setup_subtract_nd_f32(
  5165. xnn_operator_t subtract_op,
  5166. const float* input1,
  5167. const float* input2,
  5168. float* output);
  5169. enum xnn_status xnn_run_subtract_nd_f32(
  5170. size_t num_input1_dims,
  5171. const size_t* input1_shape,
  5172. size_t num_input2_dims,
  5173. const size_t* input2_shape,
  5174. const float* input1,
  5175. const float* input2,
  5176. float* output,
  5177. float output_min,
  5178. float output_max,
  5179. uint32_t flags,
  5180. pthreadpool_t threadpool);
  5181. enum xnn_status xnn_create_subtract_nd_qs8(
  5182. int8_t input1_zero_point,
  5183. float input1_scale,
  5184. int8_t input2_zero_point,
  5185. float input2_scale,
  5186. int8_t output_zero_point,
  5187. float output_scale,
  5188. int8_t output_min,
  5189. int8_t output_max,
  5190. uint32_t flags,
  5191. xnn_operator_t* subtract_op_out);
  5192. enum xnn_status xnn_reshape_subtract_nd_qs8(
  5193. xnn_operator_t subtract_op,
  5194. size_t num_input1_dims,
  5195. const size_t* input1_shape,
  5196. size_t num_input2_dims,
  5197. const size_t* input2_shape,
  5198. pthreadpool_t threadpool);
  5199. enum xnn_status xnn_setup_subtract_nd_qs8(
  5200. xnn_operator_t subtract_op,
  5201. const int8_t* input1,
  5202. const int8_t* input2,
  5203. int8_t* output);
  5204. enum xnn_status xnn_run_subtract_nd_qs8(
  5205. size_t num_input1_dims,
  5206. const size_t* input1_shape,
  5207. int8_t input1_zero_point,
  5208. float input1_scale,
  5209. size_t num_input2_dims,
  5210. const size_t* input2_shape,
  5211. int8_t input2_zero_point,
  5212. float input2_scale,
  5213. const int8_t* input1,
  5214. const int8_t* input2,
  5215. int8_t* output,
  5216. int8_t output_zero_point,
  5217. float output_scale,
  5218. int8_t output_min,
  5219. int8_t output_max,
  5220. uint32_t flags,
  5221. pthreadpool_t threadpool);
  5222. enum xnn_status xnn_create_subtract_nd_qu8(
  5223. uint8_t input1_zero_point,
  5224. float input1_scale,
  5225. uint8_t input2_zero_point,
  5226. float input2_scale,
  5227. uint8_t output_zero_point,
  5228. float output_scale,
  5229. uint8_t output_min,
  5230. uint8_t output_max,
  5231. uint32_t flags,
  5232. xnn_operator_t* subtract_op_out);
  5233. enum xnn_status xnn_reshape_subtract_nd_qu8(
  5234. xnn_operator_t subtract_op,
  5235. size_t num_input1_dims,
  5236. const size_t* input1_shape,
  5237. size_t num_input2_dims,
  5238. const size_t* input2_shape,
  5239. pthreadpool_t threadpool);
  5240. enum xnn_status xnn_setup_subtract_nd_qu8(
  5241. xnn_operator_t subtract_op,
  5242. const uint8_t* input1,
  5243. const uint8_t* input2,
  5244. uint8_t* output);
  5245. enum xnn_status xnn_run_subtract_nd_qu8(
  5246. size_t num_input1_dims,
  5247. const size_t* input1_shape,
  5248. uint8_t input1_zero_point,
  5249. float input1_scale,
  5250. size_t num_input2_dims,
  5251. const size_t* input2_shape,
  5252. uint8_t input2_zero_point,
  5253. float input2_scale,
  5254. const uint8_t* input1,
  5255. const uint8_t* input2,
  5256. uint8_t* output,
  5257. uint8_t output_zero_point,
  5258. float output_scale,
  5259. uint8_t output_min,
  5260. uint8_t output_max,
  5261. uint32_t flags,
  5262. pthreadpool_t threadpool);
  5263. enum xnn_status xnn_create_tanh_nc_f16(
  5264. uint32_t flags,
  5265. xnn_operator_t* tanh_op_out);
  5266. enum xnn_status xnn_reshape_tanh_nc_f16(
  5267. xnn_operator_t tanh_op,
  5268. size_t batch_size,
  5269. size_t channels,
  5270. size_t input_stride,
  5271. size_t output_stride,
  5272. pthreadpool_t threadpool);
  5273. enum xnn_status xnn_setup_tanh_nc_f16(
  5274. xnn_operator_t tanh_op,
  5275. const void* input,
  5276. void* output);
  5277. enum xnn_status xnn_create_tanh_nc_f32(
  5278. uint32_t flags,
  5279. xnn_operator_t* tanh_op_out);
  5280. enum xnn_status xnn_reshape_tanh_nc_f32(
  5281. xnn_operator_t tanh_op,
  5282. size_t batch_size,
  5283. size_t channels,
  5284. size_t input_stride,
  5285. size_t output_stride,
  5286. pthreadpool_t threadpool);
  5287. enum xnn_status xnn_setup_tanh_nc_f32(
  5288. xnn_operator_t tanh_op,
  5289. const float* input,
  5290. float* output);
  5291. enum xnn_status xnn_run_tanh_nc_f32(
  5292. size_t channels,
  5293. size_t input_stride,
  5294. size_t output_stride,
  5295. size_t batch_size,
  5296. const float* input,
  5297. float* output,
  5298. uint32_t flags,
  5299. pthreadpool_t threadpool);
  5300. enum xnn_status xnn_create_tanh_nc_qs8(
  5301. int8_t input_zero_point,
  5302. float input_scale,
  5303. int8_t output_zero_point,
  5304. float output_scale,
  5305. int8_t output_min,
  5306. int8_t output_max,
  5307. uint32_t flags,
  5308. xnn_operator_t* tanh_op_out);
  5309. enum xnn_status xnn_reshape_tanh_nc_qs8(
  5310. xnn_operator_t tanh_op,
  5311. size_t batch_size,
  5312. size_t channels,
  5313. size_t input_stride,
  5314. size_t output_stride,
  5315. pthreadpool_t threadpool);
  5316. enum xnn_status xnn_setup_tanh_nc_qs8(
  5317. xnn_operator_t tanh_op,
  5318. const int8_t* input,
  5319. int8_t* output);
  5320. enum xnn_status xnn_create_tanh_nc_qu8(
  5321. uint8_t input_zero_point,
  5322. float input_scale,
  5323. uint8_t output_zero_point,
  5324. float output_scale,
  5325. uint8_t output_min,
  5326. uint8_t output_max,
  5327. uint32_t flags,
  5328. xnn_operator_t* tanh_op_out);
  5329. enum xnn_status xnn_reshape_tanh_nc_qu8(
  5330. xnn_operator_t tanh_op,
  5331. size_t batch_size,
  5332. size_t channels,
  5333. size_t input_stride,
  5334. size_t output_stride,
  5335. pthreadpool_t threadpool);
  5336. enum xnn_status xnn_setup_tanh_nc_qu8(
  5337. xnn_operator_t tanh_op,
  5338. const uint8_t* input,
  5339. uint8_t* output);
  5340. enum xnn_status xnn_create_transpose_nd_x8(
  5341. uint32_t flags,
  5342. xnn_operator_t* transpose_op_out);
  5343. enum xnn_status xnn_reshape_transpose_nd_x8(
  5344. xnn_operator_t transpose_op,
  5345. size_t num_dims,
  5346. const size_t* input_shape,
  5347. const size_t* output_perm,
  5348. pthreadpool_t threadpool);
  5349. enum xnn_status xnn_setup_transpose_nd_x8(
  5350. xnn_operator_t transpose_op,
  5351. const void* input,
  5352. void* output);
  5353. enum xnn_status xnn_run_transpose_nd_x8(
  5354. const void* input,
  5355. void* output,
  5356. size_t num_dims,
  5357. const size_t* input_shape,
  5358. const size_t* output_perm,
  5359. uint32_t flags,
  5360. pthreadpool_t threadpool);
  5361. enum xnn_status xnn_create_transpose_nd_x16(
  5362. uint32_t flags,
  5363. xnn_operator_t* transpose_op_out);
  5364. enum xnn_status xnn_reshape_transpose_nd_x16(
  5365. xnn_operator_t transpose_op,
  5366. size_t num_dims,
  5367. const size_t* input_shape,
  5368. const size_t* output_perm,
  5369. pthreadpool_t threadpool);
  5370. enum xnn_status xnn_setup_transpose_nd_x16(
  5371. xnn_operator_t transpose_op,
  5372. const void* input,
  5373. void* output);
  5374. enum xnn_status xnn_run_transpose_nd_x16(
  5375. const void* input,
  5376. void* output,
  5377. size_t num_dims,
  5378. const size_t* input_shape,
  5379. const size_t* output_perm,
  5380. uint32_t flags,
  5381. pthreadpool_t threadpool);
  5382. enum xnn_status xnn_create_transpose_nd_x32(
  5383. uint32_t flags,
  5384. xnn_operator_t* transpose_op_out);
  5385. enum xnn_status xnn_reshape_transpose_nd_x32(
  5386. xnn_operator_t transpose_op,
  5387. size_t num_dims,
  5388. const size_t* input_shape,
  5389. const size_t* output_perm,
  5390. pthreadpool_t threadpool);
  5391. enum xnn_status xnn_setup_transpose_nd_x32(
  5392. xnn_operator_t transpose_op,
  5393. const void* input,
  5394. void* output);
  5395. enum xnn_status xnn_run_transpose_nd_x32(
  5396. const void* input,
  5397. void* output,
  5398. size_t num_dims,
  5399. const size_t* input_shape,
  5400. const size_t* output_perm,
  5401. uint32_t flags,
  5402. pthreadpool_t threadpool);
  5403. enum xnn_status xnn_create_transpose_nd_x64(
  5404. uint32_t flags,
  5405. xnn_operator_t* transpose_op_out);
  5406. enum xnn_status xnn_reshape_transpose_nd_x64(
  5407. xnn_operator_t transpose_op,
  5408. size_t num_dims,
  5409. const size_t* input_shape,
  5410. const size_t* output_perm,
  5411. pthreadpool_t threadpool);
  5412. enum xnn_status xnn_setup_transpose_nd_x64(
  5413. xnn_operator_t transpose_op,
  5414. const void* input,
  5415. void* output);
  5416. enum xnn_status xnn_run_transpose_nd_x64(
  5417. const void* input,
  5418. void* output,
  5419. size_t num_dims,
  5420. const size_t* input_shape,
  5421. const size_t* output_perm,
  5422. uint32_t flags,
  5423. pthreadpool_t threadpool);
  5424. enum xnn_status xnn_create_truncation_nc_f16(
  5425. uint32_t flags,
  5426. xnn_operator_t* truncation_op_out);
  5427. enum xnn_status xnn_reshape_truncation_nc_f16(
  5428. xnn_operator_t truncation_op,
  5429. size_t batch_size,
  5430. size_t channels,
  5431. size_t input_stride,
  5432. size_t output_stride,
  5433. pthreadpool_t threadpool);
  5434. enum xnn_status xnn_setup_truncation_nc_f16(
  5435. xnn_operator_t truncation_op,
  5436. const void* input,
  5437. void* output);
  5438. enum xnn_status xnn_create_truncation_nc_f32(
  5439. uint32_t flags,
  5440. xnn_operator_t* truncation_op_out);
  5441. enum xnn_status xnn_reshape_truncation_nc_f32(
  5442. xnn_operator_t truncation_op,
  5443. size_t batch_size,
  5444. size_t channels,
  5445. size_t input_stride,
  5446. size_t output_stride,
  5447. pthreadpool_t threadpool);
  5448. enum xnn_status xnn_setup_truncation_nc_f32(
  5449. xnn_operator_t truncation_op,
  5450. const float* input,
  5451. float* output);
  5452. enum xnn_status xnn_run_truncation_nc_f32(
  5453. size_t channels,
  5454. size_t input_stride,
  5455. size_t output_stride,
  5456. size_t batch_size,
  5457. const float* input,
  5458. float* output,
  5459. uint32_t flags,
  5460. pthreadpool_t threadpool);
  5461. enum xnn_status xnn_create_unpooling2d_nhwc_x32(
  5462. uint32_t input_padding_top,
  5463. uint32_t input_padding_right,
  5464. uint32_t input_padding_bottom,
  5465. uint32_t input_padding_left,
  5466. uint32_t pooling_height,
  5467. uint32_t pooling_width,
  5468. size_t channels,
  5469. size_t input_pixel_stride,
  5470. size_t output_pixel_stride,
  5471. uint32_t flags,
  5472. xnn_operator_t* unpooling_op_out);
  5473. enum xnn_status xnn_reshape_unpooling2d_nhwc_x32(
  5474. xnn_operator_t unpooling_op,
  5475. size_t batch_size,
  5476. size_t input_height,
  5477. size_t input_width,
  5478. size_t* output_height_out,
  5479. size_t* output_width_out,
  5480. pthreadpool_t threadpool);
  5481. enum xnn_status xnn_setup_unpooling2d_nhwc_x32(
  5482. xnn_operator_t unpooling_op,
  5483. const void* input,
  5484. const uint32_t* index,
  5485. void* output);
  5486. enum xnn_status xnn_create_slice_nd_x8(
  5487. uint32_t flags,
  5488. xnn_operator_t* slice_op_out);
  5489. enum xnn_status xnn_reshape_slice_nd_x8(
  5490. xnn_operator_t slice_op,
  5491. size_t num_dims,
  5492. const size_t* input_shape,
  5493. const size_t* offsets,
  5494. const size_t* sizes,
  5495. pthreadpool_t threadpool);
  5496. enum xnn_status xnn_setup_slice_nd_x8(
  5497. xnn_operator_t slice_op,
  5498. const void* input,
  5499. void* output);
  5500. enum xnn_status xnn_create_space_to_depth_nhwc_x8(
  5501. uint32_t block_size,
  5502. uint32_t flags,
  5503. xnn_operator_t* space_to_depth_op_out);
  5504. enum xnn_status xnn_reshape_space_to_depth_nhwc_x8(
  5505. xnn_operator_t space_to_depth_op,
  5506. size_t batch_size,
  5507. size_t input_height,
  5508. size_t input_width,
  5509. size_t input_channels,
  5510. size_t* output_height_out,
  5511. size_t* output_width_out,
  5512. size_t* output_channels_out,
  5513. pthreadpool_t threadpool);
  5514. enum xnn_status xnn_setup_space_to_depth_nhwc_x8(
  5515. xnn_operator_t space_to_depth_op,
  5516. const void* input,
  5517. void* output);
  5518. #ifdef __cplusplus
  5519. } // extern "C"
  5520. #endif