# mypy: ignore-errors

import unittest

from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.utils._triton import has_triton

requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")

if has_triton():
    import triton
    from triton import language as tl

    # Define here so that multiple tests can take advantage of it
    @triton.jit
    def add_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)
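
    # A typical host-side launch of add_kernel might look like the following
    # (illustrative sketch only; the tensor names x, y, out and the BLOCK_SIZE
    # value are assumptions, not part of this module):
    #     n = out.numel()
    #     grid = (triton.cdiv(n, 1024),)
    #     add_kernel[grid](x, y, out, n, BLOCK_SIZE=1024)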

    @triton.jit
    def add_kernel_with_optional_param(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        ARGS_PASSED: "tl.constexpr",
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        if ARGS_PASSED == "two":
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y
        else:
            output = x
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 128}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
        ],
        key=[],
    )
    @triton.jit
    def add_kernel_autotuned(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)
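
    # Note: for the autotuned kernels, BLOCK_SIZE comes from the selected
    # triton.Config rather than from the caller, so the grid is usually
    # computed from the config's meta-parameters (illustrative sketch only):
    #     grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
    #     add_kernel_autotuned[grid](x, y, out, n)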

    @triton.autotune(
        configs=[
            triton.Config(
                {"BLOCK_SIZE_X": 128, "BLOCK_SIZE_Y": 128}, num_stages=3, num_warps=8
            ),
            triton.Config(
                {"BLOCK_SIZE_X": 128, "BLOCK_SIZE_Y": 128}, num_stages=4, num_warps=4
            ),
            triton.Config(
                {"BLOCK_SIZE_X": 64, "BLOCK_SIZE_Y": 64}, num_stages=3, num_warps=8
            ),
            triton.Config(
                {"BLOCK_SIZE_X": 64, "BLOCK_SIZE_Y": 64}, num_stages=4, num_warps=4
            ),
        ],
        key=[],
    )
    @triton.jit
    def add_kernel_2d_autotuned(
        in_ptr0,
        in_ptr1,
        out_ptr,
        x_elements,
        y_elements,
        BLOCK_SIZE_X: "tl.constexpr",
        BLOCK_SIZE_Y: "tl.constexpr",
    ):
        xoffset = tl.program_id(0) * BLOCK_SIZE_X
        xindex = xoffset + tl.arange(0, BLOCK_SIZE_X)[:, None]
        xmask = xindex < x_elements
        yoffset = tl.program_id(1) * BLOCK_SIZE_Y
        yindex = yoffset + tl.arange(0, BLOCK_SIZE_Y)[None, :]
        ymask = yindex < y_elements
        x1 = xindex
        y0 = yindex
        tmp0 = tl.load(in_ptr0 + (x1 + (x_elements * y0)), xmask & ymask)
        tmp1 = tl.load(in_ptr0 + (y0 + (y_elements * x1)), xmask & ymask)
        tmp2 = tmp0 + tmp1
        tl.store(out_ptr + (x1 + (x_elements * y0)), tmp2, xmask & ymask)

    @triton.jit
    def add_kernel_with_scaling(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        scaling_factor,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = (x + y) * scaling_factor
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def mul2_kernel(
        in_ptr0,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        output = 2 * x
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def mul2_inplace_kernel(
        ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(ptr + offsets, mask=mask)
        output = 2 * x
        tl.store(ptr + offsets, output, mask=mask)

    @triton.jit
    def zero_negs(x):
        return tl.where(x >= 0, x, 0)
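
    # Note: indirection_kernel below calls the other jitted kernels above as
    # device functions; which path runs is selected at compile time via the
    # ACTIVATION constexpr argument.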
    @triton.jit
    def indirection_kernel(
        in_ptr0,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
        ACTIVATION: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        if ACTIVATION == "mul2_inplace_kernel":
            mul2_inplace_kernel(in_ptr0, n_elements, BLOCK_SIZE=BLOCK_SIZE)
        elif ACTIVATION == "add_kernel":
            add_kernel(in_ptr0, in_ptr0, out_ptr, n_elements, BLOCK_SIZE=BLOCK_SIZE)
        x = tl.load(in_ptr0 + offsets, mask=mask)
        tl.store(out_ptr + offsets, x, mask=mask)

    @triton.jit
    def double_strided_kernel(
        in_ptr,
        out_ptr,
        in_y_stride,
        out_y_stride,
        X_BLOCK_SIZE: "tl.constexpr",
        Y_BLOCK_SIZE: "tl.constexpr",
    ):
        xid = tl.program_id(axis=0)
        yid = tl.program_id(axis=1)
        x_start = xid * X_BLOCK_SIZE
        y_start = yid * Y_BLOCK_SIZE
        x_offsets = x_start + tl.arange(0, X_BLOCK_SIZE)
        y_offsets = y_start + tl.arange(0, Y_BLOCK_SIZE)
        src_offsets = y_offsets[:, None] * in_y_stride + x_offsets[None, :]
        dst_offsets = y_offsets[:, None] * out_y_stride + x_offsets[None, :]
        src = tl.load(in_ptr + src_offsets)
        tl.store(out_ptr + dst_offsets, src * 2.0)
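
    # The next kernel lowers to PTX inline assembly (a 32-bit funnel shift),
    # so it is CUDA-specific.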
    @triton.jit
    def inline_asm_kernel(X, Y, Z, n: "tl.constexpr", BLOCK: "tl.constexpr"):
        x = tl.load(X + tl.arange(0, BLOCK))
        y = tl.load(Y + tl.arange(0, BLOCK))
        s = tl.full([BLOCK], n, tl.int32)
        z = tl.inline_asm_elementwise(
            "shf.l.wrap.b32 $0, $1, $2, $3;",
            "=r,r, r, r",
            [x, y, s],
            dtype=tl.int32,
            is_pure=True,
            pack=1,
        )
        tl.store(Z + tl.arange(0, BLOCK), z)

    @triton.jit
    def add_kernel_with_block_ptr(
        x_ptr,
        y_ptr,
        output_ptr,
        n_elements,
        BLOCK_SIZE: tl.constexpr,
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        x = tl.load(
            tl.make_block_ptr(
                base=x_ptr,
                shape=[n_elements],
                strides=[1],
                offsets=[block_start],
                block_shape=[BLOCK_SIZE],
                order=[0],
            ),
            boundary_check=[0],
        )
        y = tl.load(
            tl.make_block_ptr(
                base=y_ptr,
                shape=[n_elements],
                strides=[1],
                offsets=[block_start],
                block_shape=[BLOCK_SIZE],
                order=[0],
            ),
            boundary_check=[0],
        )
        output = x + y
        tl.store(
            tl.make_block_ptr(
                base=output_ptr,
                shape=[n_elements],
                strides=[1],
                offsets=[block_start],
                block_shape=[BLOCK_SIZE],
                order=[0],
            ),
            output,
            boundary_check=[0],
        )

    @triton.jit
    def kernel_with_block_ptr_2d(
        x_ptr,
        output_ptr,
        n_elements,
        BLOCK_SIZE: tl.constexpr,
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        x = tl.load(
            tl.make_block_ptr(
                base=x_ptr,
                shape=[n_elements, 1],
                strides=[1, 1],
                offsets=[block_start, 0],
                block_shape=[BLOCK_SIZE, 1],
                order=[1, 0],
            ),
            boundary_check=[0],
        )
        output = x
        tl.store(
            tl.make_block_ptr(
                base=output_ptr,
                shape=[n_elements, 1],
                strides=[1, 1],
                offsets=[block_start, 0],
                block_shape=[BLOCK_SIZE, 1],
                order=[1, 0],
            ),
            output,
            boundary_check=[0],
        )

    from triton.language import load, store

    @triton.jit
    def add_kernel_with_import(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = load(in_ptr0 + offsets, mask=mask)
        y = load(in_ptr1 + offsets, mask=mask)
        output = x + y
        store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def cond_op_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        if tl.program_id(0) == 0:
            output = x + y
        else:
            output = x * y
        tl.store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def atomic_add_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.atomic_add(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def add_4_times_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        for i in range(2):
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)
        i = 2
        while i > 0:
            i -= 1
            output = x + y
            tl.store(out_ptr + offsets, output, mask=mask)

    @triton.jit
    def add_kernel_out_of_order_fn2(
        in_ptr0,
        in_ptr1,
        n_elements,
        out_ptr,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)