
#!/usr/bin/env python3
# mypy: allow-untyped-defs
""" The Python Hipify script.
##
# Copyright (c) 2015-2016 Advanced Micro Devices, Inc. All rights reserved.
#               2017-2018 Advanced Micro Devices, Inc. and
#                         Facebook Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
"""
import argparse
import fnmatch
import re
import shutil
import sys
import os

from . import constants
from .cuda_to_hip_mappings import CUDA_TO_HIP_MAPPINGS
from .cuda_to_hip_mappings import MATH_TRANSPILATIONS

from typing import Dict, List, Iterator, Optional
from collections.abc import Mapping, Iterable
from enum import Enum

class CurrentState(Enum):
    INITIALIZED = 1
    DONE = 2


class HipifyResult:
    def __init__(self, current_state, hipified_path):
        self.current_state = current_state
        self.hipified_path = hipified_path
        self.status = ""

    def __str__(self):
        return (f"HipifyResult:: current_state: {self.current_state}, hipified_path : {self.hipified_path}, status: {self.status}")


HipifyFinalResult = Dict[str, HipifyResult]
HIPIFY_C_BREADCRUMB = "// !!! This is a file automatically generated by hipify!!!\n"
HIPIFY_FINAL_RESULT: HipifyFinalResult = {}

# Hardcode the PyTorch template map
"""This dictionary provides the mapping from PyTorch kernel template types
to their actual types."""
PYTORCH_TEMPLATE_MAP = {"Dtype": "scalar_t", "T": "scalar_t"}

__all__ = ['InputError', 'openf', 'bcolors', 'GeneratedFileCleaner', 'match_extensions', 'matched_files_iter',
           'preprocess_file_and_save_result', 'compute_stats', 'add_dim3', 'processKernelLaunches', 'find_closure_group',
           'find_bracket_group', 'find_parentheses_group', 'replace_math_functions', 'hip_header_magic', 'replace_extern_shared',
           'get_hip_file_path', 'is_out_of_place', 'is_pytorch_file', 'is_cusparse_file', 'is_special_file',
           'is_caffe2_gpu_file', 'Trie', 'preprocessor', 'file_specific_replacement', 'file_add_header',
           'fix_static_global_kernels', 'extract_arguments', 'str2bool', 'CurrentState', 'HipifyResult', 'hipify']

class InputError(Exception):
    # Exception raised for errors in the input.

    def __init__(self, message):
        super().__init__(message)
        self.message = message

    def __str__(self):
        return f"Input error: {self.message}"


def openf(filename, mode):
    return open(filename, mode, errors='ignore')


# Color coding for printing
class bcolors:
    HEADER = '\033[95m'
    OKBLUE = '\033[94m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    FAIL = '\033[91m'
    ENDC = '\033[0m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

# To the programmer, the outputs of hipify are most likely intermediates.
# This class lets users of hipify request a cleanup by running hipify and the
# compilation inside a `with` block that instantiates this context manager
# with keep_intermediates=False.
# The main use case is cpp_extensions, specifically the load method.
# It is a good idea to keep intermediates (in case of errors, or to avoid
# recompiling unchanged files), but in cases where you don't want to keep
# them (e.g. in CI), this can be used to remove the generated files.
class GeneratedFileCleaner:
    """Context Manager to clean up generated files"""
    def __init__(self, keep_intermediates=False):
        self.keep_intermediates = keep_intermediates
        self.files_to_clean = set()
        self.dirs_to_clean = []

    def __enter__(self):
        return self

    def open(self, fn, *args, **kwargs):
        if not os.path.exists(fn):
            self.files_to_clean.add(os.path.abspath(fn))
        return open(fn, *args, **kwargs)

    def makedirs(self, dn, exist_ok=False):
        parent, n = os.path.split(dn)
        if not n:
            parent, n = os.path.split(parent)
        if parent and n and not os.path.exists(parent):
            self.makedirs(parent, exist_ok=True)
        if not os.path.isdir(dn) or not exist_ok:
            os.mkdir(dn)
            self.dirs_to_clean.append(os.path.abspath(dn))

    def __exit__(self, type, value, traceback):
        if not self.keep_intermediates:
            for f in self.files_to_clean:
                os.unlink(f)
            for d in self.dirs_to_clean[::-1]:
                os.rmdir(d)

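# Illustrative usage of GeneratedFileCleaner (a minimal sketch; the directory
# and file names are hypothetical):
#
#   with GeneratedFileCleaner(keep_intermediates=False) as clean_ctx:
#       clean_ctx.makedirs("build/hip", exist_ok=True)
#       with clean_ctx.open("build/hip/foo_hip.cpp", "w") as f:
#           f.write("// generated\n")
#   # On exit, files and directories that did not exist beforehand are removed.
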
def match_extensions(filename: str, extensions: Iterable) -> bool:
    """Helper method to see if filename ends with certain extension"""
    return any(filename.endswith(e) for e in extensions)


def _fnmatch(filepath, patterns):
    return any(fnmatch.fnmatch(filepath, pattern) for pattern in patterns)


def matched_files_iter(
        root_path: str,
        includes: Iterable = (),
        ignores: Iterable = (),
        extensions: Iterable = (),
        out_of_place_only: bool = False,
        is_pytorch_extension: bool = False) -> Iterator[str]:
    exact_matches = set(includes)

    # This is a very rough heuristic; really, we want to avoid scanning
    # any file which is not checked into source control, but this script
    # needs to work even if you're in a Git or Hg checkout, so easier to
    # just block the biggest time sinks that won't matter in the end.
    for (abs_dirpath, dirs, filenames) in os.walk(root_path, topdown=True):
        rel_dirpath = os.path.relpath(abs_dirpath, root_path)
        if rel_dirpath == '.':
            # Blah blah blah O(n) blah blah
            if ".git" in dirs:
                dirs.remove(".git")
            if "build" in dirs:
                dirs.remove("build")
            if "third_party" in dirs:
                dirs.remove("third_party")
                dirs.append("third_party/nvfuser")
        for filename in filenames:
            filepath = os.path.join(abs_dirpath, filename)
            rel_filepath = os.path.join(rel_dirpath, filename)
            # We respect extensions, UNLESS you wrote the entire
            # filename verbatim, in which case we always accept it
            if (
                _fnmatch(filepath, includes)
                and (not _fnmatch(filepath, ignores))
                and (match_extensions(filepath, extensions) or filepath in exact_matches)
            ):
                if not is_pytorch_extension:  # for pytorch extensions, consider all files
                    if not is_pytorch_file(rel_filepath) and not is_caffe2_gpu_file(rel_filepath):
                        continue
                    if out_of_place_only and not is_out_of_place(rel_filepath):
                        continue
                yield filepath

def preprocess_file_and_save_result(
        output_directory: str,
        filepath: str,
        all_files: Iterable,
        header_include_dirs: Iterable,
        stats: Dict[str, List],
        hip_clang_launch: bool,
        is_pytorch_extension: bool,
        clean_ctx: GeneratedFileCleaner,
        show_progress: bool) -> None:
    fin_path = os.path.abspath(os.path.join(output_directory, filepath))
    hipify_result = HipifyResult(current_state=CurrentState.INITIALIZED, hipified_path=fin_path)
    HIPIFY_FINAL_RESULT[fin_path] = hipify_result
    result = preprocessor(output_directory, filepath, all_files, header_include_dirs, stats,
                          hip_clang_launch, is_pytorch_extension, clean_ctx, show_progress)
    # Show what happened
    if show_progress and "ignored" not in result.status:
        print(
            fin_path, "->",
            result.hipified_path, result.status, flush=True)
    HIPIFY_FINAL_RESULT[fin_path] = result


def compute_stats(stats):
    unsupported_calls = {cuda_call for (cuda_call, _filepath) in stats["unsupported_calls"]}

    # Print the number of unsupported calls
    print(f"Total number of unsupported CUDA function calls: {len(unsupported_calls):d}")

    # Print the list of unsupported calls
    print(", ".join(unsupported_calls))

    # Print the number of kernel launches
    print(f"\nTotal number of replaced kernel launches: {len(stats['kernel_launches']):d}")

def add_dim3(kernel_string, cuda_kernel):
    '''adds dim3() around the first two arguments (grid and block dims) of the <<<...>>> launch config'''
    count = 0
    closure = 0
    kernel_string = kernel_string.replace("<<<", "").replace(">>>", "")
    arg_locs: List[Dict[str, int]] = [{} for _ in range(2)]
    arg_locs[count]['start'] = 0
    for ind, c in enumerate(kernel_string):
        if count > 1:
            break
        if c == "(":
            closure += 1
        elif c == ")":
            closure -= 1
        if (c == "," or ind == len(kernel_string) - 1) and closure == 0:
            arg_locs[count]['end'] = ind + (c != ",")
            count += 1
            if count < 2:
                arg_locs[count]['start'] = ind + 1

    first_arg_raw = kernel_string[arg_locs[0]['start']:arg_locs[0]['end'] + 1]
    second_arg_raw = kernel_string[arg_locs[1]['start']:arg_locs[1]['end']]

    first_arg_clean = kernel_string[arg_locs[0]['start']:arg_locs[0]['end']].replace("\n", "").strip(" ")
    second_arg_clean = kernel_string[arg_locs[1]['start']:arg_locs[1]['end']].replace("\n", "").strip(" ")

    first_arg_dim3 = f"dim3({first_arg_clean})"
    second_arg_dim3 = f"dim3({second_arg_clean})"

    first_arg_raw_dim3 = first_arg_raw.replace(first_arg_clean, first_arg_dim3)
    second_arg_raw_dim3 = second_arg_raw.replace(second_arg_clean, second_arg_dim3)
    cuda_kernel = cuda_kernel.replace(first_arg_raw + second_arg_raw, first_arg_raw_dim3 + second_arg_raw_dim3)
    return cuda_kernel

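# Illustrative example (a minimal sketch; the kernel and argument names are
# hypothetical). add_dim3 wraps the grid/block expressions in dim3():
#
#   >>> add_dim3("<<<grid, block>>>", "my_kernel<<<grid, block>>>(")
#   'my_kernel<<<dim3(grid), dim3(block)>>>('
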
RE_KERNEL_LAUNCH = re.compile(r'([ ]+)(detail?)::[ ]+\\\n[ ]+')


def processKernelLaunches(string, stats):
    """ Replace the CUDA style Kernel launches with the HIP style kernel launches."""
    # Concat the namespace with the kernel names. (Find cleaner way of doing this later).
    string = RE_KERNEL_LAUNCH.sub(lambda inp: f"{inp.group(1)}{inp.group(2)}::", string)

    def grab_method_and_template(in_kernel):
        # The positions for relevant kernel components.
        pos = {
            "kernel_launch": {"start": in_kernel["start"], "end": in_kernel["end"]},
            "kernel_name": {"start": -1, "end": -1},
            "template": {"start": -1, "end": -1}
        }

        # Count for balancing template
        count = {"<>": 0}

        # Status for whether we are parsing a certain item.
        START = 0
        AT_TEMPLATE = 1
        AFTER_TEMPLATE = 2
        AT_KERNEL_NAME = 3

        status = START

        # Parse the string character by character
        for i in range(pos["kernel_launch"]["start"] - 1, -1, -1):
            char = string[i]

            # Handle Templating Arguments
            if status in (START, AT_TEMPLATE):
                if char == ">":
                    if status == START:
                        status = AT_TEMPLATE
                        pos["template"]["end"] = i
                    count["<>"] += 1

                if char == "<":
                    count["<>"] -= 1
                    if count["<>"] == 0 and (status == AT_TEMPLATE):
                        pos["template"]["start"] = i
                        status = AFTER_TEMPLATE

            # Handle Kernel Name
            if status != AT_TEMPLATE:
                if string[i].isalnum() or string[i] in {'(', ')', '_', ':', '#'}:
                    if status != AT_KERNEL_NAME:
                        status = AT_KERNEL_NAME
                        pos["kernel_name"]["end"] = i

                    # Case: Kernel name starts the string.
                    if i == 0:
                        pos["kernel_name"]["start"] = 0

                        # Finished
                        return [(pos["kernel_name"]), (pos["template"]), (pos["kernel_launch"])]

                else:
                    # Potential ending point if we're already traversing a kernel's name.
                    if status == AT_KERNEL_NAME:
                        pos["kernel_name"]["start"] = i

                        # Finished
                        return [(pos["kernel_name"]), (pos["template"]), (pos["kernel_launch"])]
    def find_kernel_bounds(string):
        """Finds the starting and ending points for all kernel launches in the string."""
        kernel_end = 0
        kernel_positions = []

        # Continue until we cannot find any more kernels.
        while string.find("<<<", kernel_end) != -1:
            # Get kernel starting position (starting from the previous ending point)
            kernel_start = string.find("<<<", kernel_end)

            # Get kernel ending position (adjust end point past the >>>)
            kernel_end = string.find(">>>", kernel_start) + 3
            if kernel_end <= 0:
                raise InputError("no kernel end found")

            # Add to list of traversed kernels
            kernel_positions.append({"start": kernel_start, "end": kernel_end,
                                     "group": string[kernel_start: kernel_end]})
        return kernel_positions
    # Replace comments and string literals from the code so that find_kernel_bounds does not
    # wrongly capture kernels in comments and string literals.
    # This function replaces them with "x" to keep positions.
    def mask_comments(string):
        in_comment = ''
        prev_c = ''
        new_string = ''
        for c in string:
            if in_comment == '':
                # Outside comments
                if c == '/' and prev_c == '/':
                    in_comment = '//'
                elif c == '*' and prev_c == '/':
                    in_comment = '/*'
                elif c == '"' and prev_c != '\\' and prev_c != "'":
                    in_comment = '"'
            elif in_comment == '//':
                # In // xxx
                if c == '\r' or c == '\n':
                    in_comment = ''
            elif in_comment == '/*':
                # In /* xxx */
                if c == '/' and prev_c == '*':
                    in_comment = ''
            elif in_comment == '"':
                # In ""
                if c == '"' and prev_c != '\\':
                    in_comment = ''
            prev_c = c
            if in_comment == '':
                new_string += c
            else:
                new_string += 'x'
        return new_string
    # Grab positional ranges of all kernel launches
    get_kernel_positions = list(find_kernel_bounds(mask_comments(string)))
    output_string = string

    # Replace each CUDA kernel with a HIP kernel.
    for kernel in get_kernel_positions:
        # Get kernel components
        params = grab_method_and_template(kernel)

        # Find parenthesis after kernel launch
        parenthesis = string.find("(", kernel["end"])

        # Extract cuda kernel
        cuda_kernel = string[params[0]["start"]:parenthesis + 1]
        kernel_string = string[kernel['start']:kernel['end']]
        end_param_index = 0 if params[1]['end'] == -1 else 1
        kernel_name_with_template = string[params[0]['start']:params[end_param_index]['end'] + 1]
        cuda_kernel_dim3 = add_dim3(kernel_string, cuda_kernel)

        # Keep number of kernel launch params consistent (grid dims, group dims, stream, dynamic shared size)
        num_klp = len(extract_arguments(0, kernel["group"].replace("<<<", "(").replace(">>>", ")")))

        hip_kernel = "hipLaunchKernelGGL(" + cuda_kernel_dim3[0:-1].replace(
            ">>>", ", 0" * (4 - num_klp) + ">>>").replace("<<<", ", ").replace(
            ">>>", ", ").replace(kernel_name_with_template, "(" + kernel_name_with_template + ")")

        # Replace cuda kernel with hip kernel
        output_string = output_string.replace(cuda_kernel, hip_kernel)

        # Update the statistics
        stats["kernel_launches"].append(hip_kernel)

    return output_string

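# Illustrative end-to-end example of the launch rewrite (a minimal sketch; the
# kernel name and launch arguments are hypothetical):
#
#   CUDA:  my_kernel<<<grid, block, 0, stream>>>(arg);
#   HIP:   hipLaunchKernelGGL((my_kernel), dim3(grid), dim3(block), 0, stream, arg);
#
# Launches with fewer than four <<<...>>> parameters are padded with ", 0" so
# that hipLaunchKernelGGL always receives grid, block, shared-mem, and stream.
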
def find_closure_group(input_string, start, group):
    """Generalization for finding a balancing closure group

         if group = ["(", ")"], then finds the first balanced parentheses.
         if group = ["{", "}"], then finds the first balanced bracket.

    Given an input string, a starting position in the input string, and the group type,
    find_closure_group returns the positions of group[0] and group[1] as a tuple.

    Example:
        >>> find_closure_group("(hi)", 0, ["(", ")"])
        (0, 3)
    """

    inside_parenthesis = False
    parens = 0
    pos = start
    p_start, p_end = -1, -1

    while pos < len(input_string):
        if input_string[pos] == group[0]:
            if inside_parenthesis is False:
                inside_parenthesis = True
                parens = 1
                p_start = pos
            else:
                parens += 1
        elif input_string[pos] == group[1] and inside_parenthesis:
            parens -= 1

            if parens == 0:
                p_end = pos
                return p_start, p_end

        pos += 1
    return None, None

def find_bracket_group(input_string, start):
    """Finds the first balanced bracket group ({ ... })."""
    return find_closure_group(input_string, start, group=["{", "}"])


def find_parentheses_group(input_string, start):
    """Finds the first balanced parentheses group (( ... ))."""
    return find_closure_group(input_string, start, group=["(", ")"])

RE_ASSERT = re.compile(r"\bassert[ ]*\(")


def replace_math_functions(input_string):
    """FIXME: Temporarily replace std:: invocations of math functions
        with non-std:: versions to prevent linker errors NOTE: This
        can lead to correctness issues when running tests, since the
        correct version of the math function (exp/expf) might not get
        called. Plan is to remove this function once HIP supports
        std:: math function calls inside device code
    """
    output_string = input_string
    for func in MATH_TRANSPILATIONS:
        output_string = output_string.replace(fr'{func}(', f'{MATH_TRANSPILATIONS[func]}(')

    return output_string

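# Illustrative example (a minimal sketch; it assumes MATH_TRANSPILATIONS
# contains an entry such as "std::min" -> "::min"):
#
#   >>> replace_math_functions("auto m = std::min(a, b);")
#   'auto m = ::min(a, b);'
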
RE_SYNCTHREADS = re.compile(r":?:?\b(__syncthreads)\b(\w*\()")


def hip_header_magic(input_string):
    """If the file makes kernel builtin calls and does not include the cuda_runtime.h header,
    then automatically add an #include to match the "magic" includes provided by NVCC.
    TODO:
        Update logic to ignore cases where the cuda_runtime.h is included by another file.
    """

    # Copy the input.
    output_string = input_string

    # Check if one of the following headers is already included.
    headers = ["hip/hip_runtime.h", "hip/hip_runtime_api.h"]
    if any(re.search(fr'#include ("{ext}"|<{ext}>)', output_string) for ext in headers):
        return output_string

    # Rough logic to detect if we're inside device code
    hasDeviceLogic: int
    hasDeviceLogic = "hipLaunchKernelGGL" in output_string
    hasDeviceLogic += "__global__" in output_string
    hasDeviceLogic += "__shared__" in output_string
    hasDeviceLogic += RE_SYNCTHREADS.search(output_string) is not None

    # If device logic found, provide the necessary header.
    if hasDeviceLogic:
        output_string = '#include "hip/hip_runtime.h"\n' + input_string

    return output_string

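# Illustrative example (a minimal sketch; the kernel body is hypothetical):
# a source containing device constructs but no HIP runtime include gets the
# header prepended, while a file that already includes it is left unchanged.
#
#   >>> hip_header_magic("__global__ void k() {}")
#   '#include "hip/hip_runtime.h"\n__global__ void k() {}'
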
RE_EXTERN_SHARED = re.compile(r"extern\s+([\w\(\)]+)?\s*__shared__\s+([\w:<>\s]+)\s+(\w+)\s*\[\s*\]\s*;")


def replace_extern_shared(input_string):
    """Match extern __shared__ type foo[]; syntax and use HIP_DYNAMIC_SHARED() MACRO instead.
       https://github.com/ROCm-Developer-Tools/HIP/blob/master/docs/markdown/hip_kernel_language.md#__shared__
    Example:
        "extern __shared__ char smemChar[];" => "HIP_DYNAMIC_SHARED( char, smemChar)"
        "extern __shared__ unsigned char smem[];" => "HIP_DYNAMIC_SHARED( unsigned char, my_smem)"
    """
    output_string = input_string
    output_string = RE_EXTERN_SHARED.sub(
        lambda inp: f"HIP_DYNAMIC_SHARED({inp.group(1) or ''} {inp.group(2)}, {inp.group(3)})", output_string)

    return output_string

def get_hip_file_path(rel_filepath, is_pytorch_extension=False):
    """
    Returns the new name of the hipified file
    """
    # At the moment, some PyTorch source files are HIPified in place. The predicate
    # is_out_of_place tells us if this is the case or not.
    assert not os.path.isabs(rel_filepath)
    if not is_pytorch_extension and not is_out_of_place(rel_filepath):
        return rel_filepath

    dirpath, filename = os.path.split(rel_filepath)
    root, ext = os.path.splitext(filename)

    # Here's the plan:
    #
    # In general, we need to disambiguate the HIPified filename so that
    # it gets a different name from the original filename, so
    # that we don't overwrite the original file
    #
    # There's a lot of different naming conventions across PyTorch
    # and Caffe2, but the general recipe is to convert occurrences
    # of cuda/gpu to hip, and add hip if there are no occurrences
    # of cuda/gpu anywhere.
    #
    # Concretely, we do the following:
    #
    #   - If there is a directory component named "cuda", replace
    #     it with "hip", AND
    #
    #   - If the file name contains "CUDA", replace it with "HIP", AND
    #
    #   - ALWAYS replace '.cu' with '.hip', because those files
    #     contain CUDA kernels that need to be hipified and processed with
    #     the hip compiler
    #
    #   - If we are not hipifying a PyTorch extension, and the parent
    #     directory name did not change as a result of the above
    #     transformations, insert "hip" in the file path
    #     as the direct parent folder of the file
    #
    #   - If we are hipifying a PyTorch extension, and the parent directory
    #     name as well as the filename (incl. extension) did not change as
    #     a result of the above transformations, insert "_hip" in the filename
    #
    # This isn't set in stone; we might adjust this to support other
    # naming conventions.

    if ext == '.cu':
        ext = '.hip'

    orig_filename = filename
    orig_dirpath = dirpath

    dirpath = dirpath.replace('cuda', 'hip')
    dirpath = dirpath.replace('CUDA', 'HIP')
    dirpath = dirpath.replace('THC', 'THH')

    root = root.replace('cuda', 'hip')
    root = root.replace('CUDA', 'HIP')
    # Special case to handle caffe2/core/THCCachingAllocator
    if dirpath != "caffe2/core":
        root = root.replace('THC', 'THH')

    if not is_pytorch_extension and dirpath == orig_dirpath:
        dirpath = os.path.join(dirpath, 'hip')

    if is_pytorch_extension and dirpath == orig_dirpath and (root + ext) == orig_filename:
        root = root + "_hip"

    return os.path.join(dirpath, root + ext)

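# Illustrative examples of the renaming rules above (a minimal sketch; the
# paths are hypothetical, and shown with POSIX separators):
#
#   >>> get_hip_file_path("aten/src/ATen/native/cuda/Foo.cu")
#   'aten/src/ATen/native/hip/Foo.hip'
#   >>> get_hip_file_path("my_ext/my_kernels.cu", is_pytorch_extension=True)
#   'my_ext/my_kernels.hip'
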
def is_out_of_place(rel_filepath):
    assert not os.path.isabs(rel_filepath)
    if rel_filepath.startswith("torch/"):
        return False
    if rel_filepath.startswith("third_party/nvfuser/"):
        return False
    if rel_filepath.startswith("tools/autograd/templates/"):
        return False
    return True


# Keep this synchronized with includes/ignores in build_amd.py
def is_pytorch_file(rel_filepath):
    assert not os.path.isabs(rel_filepath)
    if rel_filepath.startswith("aten/"):
        if rel_filepath.startswith("aten/src/ATen/core/"):
            return False
        return True
    if rel_filepath.startswith("torch/"):
        return True
    if rel_filepath.startswith("third_party/nvfuser/"):
        return True
    if rel_filepath.startswith("tools/autograd/templates/"):
        return True
    return False


def is_cusparse_file(rel_filepath):
    if is_pytorch_file(rel_filepath):
        return "sparse" in rel_filepath.lower()
    return False


def is_special_file(rel_filepath):
    if is_pytorch_file(rel_filepath):
        if "sparse" in rel_filepath.lower():
            return True
        elif "linalg" in rel_filepath.lower():
            if "batchlinearalgebralibblas" in rel_filepath.lower():
                return False  # don't use "special" mappings for this specific linalg cublas file
            return True
    return False


def is_caffe2_gpu_file(rel_filepath):
    assert not os.path.isabs(rel_filepath)
    if rel_filepath.startswith("c10/cuda"):
        return True
    filename = os.path.basename(rel_filepath)
    _, ext = os.path.splitext(filename)
    return ('gpu' in filename or ext in ['.cu', '.cuh']) and ('cudnn' not in filename)

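# Illustrative classification (a minimal sketch; the file names are hypothetical):
#
#   >>> is_caffe2_gpu_file("caffe2/operators/relu_op.cu")
#   True
#   >>> is_caffe2_gpu_file("caffe2/operators/conv_op_cudnn.cc")
#   False
#   >>> is_special_file("aten/src/ATen/native/sparse/cuda/SparseMM.cu")
#   True
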
class TrieNode:
    """A Trie node whose children are represented as a dictionary of char: TrieNode.
    A special char '' represents the end of a word.
    """
    def __init__(self):
        self.children = {}

class Trie:
    """Creates a Trie out of a list of words. The trie can be exported to a Regex pattern.
    The corresponding Regex should match much faster than a simple Regex union."""

    def __init__(self):
        """Initialize the trie with an empty root node."""
        self.root = TrieNode()

    def add(self, word):
        """Add a word to the Trie. """
        node = self.root

        for char in word:
            node.children.setdefault(char, TrieNode())
            node = node.children[char]
        node.children[''] = True  # Mark the end of the word

    def dump(self):
        """Return the root node of Trie. """
        return self.root

    def quote(self, char):
        """ Escape a char for regex. """
        return re.escape(char)

    def search(self, word):
        """Search whether word is present in the Trie.
        Returns True if yes, else return False"""
        node = self.root
        for char in word:
            if char in node.children:
                node = node.children[char]
            else:
                return False

        # make sure to check the end-of-word marker present
        return '' in node.children

    def _pattern(self, root):
        """Convert a Trie into a regular expression pattern"""
        node = root

        if "" in node.children and len(node.children.keys()) == 1:
            return None

        alt = []  # store alternative patterns
        cc = []   # store char to char classes

        q = 0  # for node representing the end of word

        for char in sorted(node.children.keys()):
            if isinstance(node.children[char], TrieNode):
                try:
                    recurse = self._pattern(node.children[char])
                    alt.append(self.quote(char) + recurse)
                except Exception:
                    cc.append(self.quote(char))
            else:
                q = 1
        cconly = not len(alt) > 0

        if len(cc) > 0:
            if len(cc) == 1:
                alt.append(cc[0])
            else:
                alt.append('[' + ''.join(cc) + ']')

        if len(alt) == 1:
            result = alt[0]
        else:
            result = "(?:" + "|".join(alt) + ")"

        if q:
            if cconly:
                result += "?"
            else:
                result = f"(?:{result})?"
        return result

    def pattern(self):
        """Export the Trie to a regex pattern."""
        return self._pattern(self.root)

    def export_to_regex(self):
        """Export the Trie to a regex pattern."""
        return self._pattern(self.root)

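# Illustrative usage (a minimal sketch; the two identifiers are hypothetical):
#
#   >>> t = Trie()
#   >>> t.add("cudaMalloc")
#   >>> t.add("cudaMemcpy")
#   >>> t.export_to_regex()
#   'cudaM(?:alloc|emcpy)'
#
# The shared prefix is factored out, so the compiled pattern scans the input
# once instead of trying each alternative of a plain union in turn.
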
CAFFE2_TRIE = Trie()
CAFFE2_MAP = {}
PYTORCH_TRIE = Trie()
PYTORCH_MAP: Dict[str, object] = {}

# In PyTorch, we map cuBLAS->rocBLAS and cuSPARSE->hipSPARSE. Note the prefix, roc versus hip.
# The 'hip' APIs offer a more direct CUDA-friendly mapping, but calling rocBLAS directly has better performance.
# Unfortunately, the roc* types and hip* types differ, i.e., rocblas_float_complex versus hipComplex.
# In the case of SPARSE, we must use the hip types for complex instead of the roc types,
# but the pytorch mappings assume roc. Therefore, we create a new SPARSE mapping that has a higher priority.
# Its mappings will trigger first, and only when a miss occurs will the lower-priority pytorch mapping take place.
# When a file contains "sparse" in the filename, a mapping marked with API_SPARSE is preferred over other choices.
# Similarly, "linalg" files require rocBLAS -> hipSOLVER so they also need special handling.
PYTORCH_SPECIAL_MAP = {}

for mapping in CUDA_TO_HIP_MAPPINGS:
    assert isinstance(mapping, Mapping)
    for src, value in mapping.items():
        dst = value[0]
        meta_data = value[1:]
        if constants.API_CAFFE2 not in meta_data:
            PYTORCH_TRIE.add(src)
            # if src is already in PYTORCH_MAP and dst belongs to API_SPECIAL
            # do not overwrite PYTORCH_MAP, store dst separately
            if constants.API_SPECIAL in meta_data and PYTORCH_MAP.get(src, ""):
                PYTORCH_SPECIAL_MAP[src] = dst
            else:
                PYTORCH_MAP[src] = dst
        if constants.API_PYTORCH not in meta_data and constants.API_SPECIAL not in meta_data:
            CAFFE2_TRIE.add(src)
            CAFFE2_MAP[src] = dst

RE_CAFFE2_PREPROCESSOR = re.compile(CAFFE2_TRIE.export_to_regex())
RE_PYTORCH_PREPROCESSOR = re.compile(fr'(?<=\W)({PYTORCH_TRIE.export_to_regex()})(?=\W)')

RE_QUOTE_HEADER = re.compile(r'#include "([^"]+)"')
RE_ANGLE_HEADER = re.compile(r'#include <([^>]+)>')
RE_THC_GENERIC_FILE = re.compile(r'#define THC_GENERIC_FILE "([^"]+)"')
RE_CU_SUFFIX = re.compile(r'\.cu\b')  # be careful not to pick up .cuh

"""
Returns a HipifyResult object with the following details:
    "hipified_path" : absolute path of hipified source file
    "status"        : "ok"      if hipified file was written out
                      "skipped" if an identical hipified file already existed or hipified file couldn't be written out
                      "ignored" if the source file was a hipified file itself or not meant to be hipified
    "current_state" : CurrentState.INITIALIZED if source file is first ready to be hipified
                      CurrentState.DONE if source file is done with hipification process
"""


def preprocessor(
        output_directory: str,
        filepath: str,
        all_files: Iterable,
        header_include_dirs: Iterable,
        stats: Dict[str, List],
        hip_clang_launch: bool,
        is_pytorch_extension: bool,
        clean_ctx: GeneratedFileCleaner,
        show_progress: bool) -> HipifyResult:
    """ Executes the CUDA -> HIP conversion on the specified file. """
    fin_path = os.path.abspath(os.path.join(output_directory, filepath))
    hipify_result = HIPIFY_FINAL_RESULT[fin_path]
    if filepath not in all_files:
        hipify_result.hipified_path = None
        hipify_result.status = "[ignored, not to be hipified]"
        hipify_result.current_state = CurrentState.DONE
        return hipify_result

    rel_filepath = os.path.relpath(filepath, output_directory)

    with open(fin_path, encoding='utf-8') as fin:
        if fin.readline() == HIPIFY_C_BREADCRUMB:
            hipify_result.hipified_path = None
            hipify_result.status = "[ignored, input is hipified output]"
            hipify_result.current_state = CurrentState.DONE
            return hipify_result
        fin.seek(0)
        output_source = fin.read()

    orig_output_source = output_source

    # get_hip_file_path needs a relative path to work correctly
    fout_path = os.path.abspath(os.path.join(output_directory, get_hip_file_path(rel_filepath, is_pytorch_extension)))
    if not os.path.exists(os.path.dirname(fout_path)):
        clean_ctx.makedirs(os.path.dirname(fout_path))

    # unsupported_calls statistics reporting is broken atm
    def pt_repl(m):
        return PYTORCH_MAP[m.group(0)]

    def pt_special_repl(m):
        # checks SPECIAL map first, and if a miss occurs, falls back to pytorch mappings
        return PYTORCH_SPECIAL_MAP.get(m.group(0), pt_repl(m))

    if is_pytorch_extension:
        output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source)
    else:
        if is_special_file(rel_filepath):
            output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_special_repl, output_source)
        elif is_pytorch_file(rel_filepath):
            output_source = RE_PYTORCH_PREPROCESSOR.sub(pt_repl, output_source)
        else:
            def c2_repl(m):
                return CAFFE2_MAP[m.group(0)]
            output_source = RE_CAFFE2_PREPROCESSOR.sub(c2_repl, output_source)
    # Header rewrites
    def mk_repl(templ, include_current_dir=True):
        def repl(m):
            f = m.group(1)
            dirpath, filename = os.path.split(f)
            if (
                f.startswith(("ATen/cuda",
                              "ATen/native/cuda",
                              "ATen/native/nested/cuda",
                              "ATen/native/quantized/cuda",
                              "ATen/native/sparse/cuda",
                              "ATen/native/transformers/cuda",
                              "THC/")) or
                (f.startswith("THC") and not f.startswith("THCP"))
            ):
                return templ.format(get_hip_file_path(m.group(1), is_pytorch_extension))
            # if filename is one of the files being hipified for this extension
            if (is_pytorch_extension and any(s.endswith(filename) for s in all_files)):
                header_dir = None
                header_filepath = None
                # If include_current_dir True, look first in same dir as the including source file
                if include_current_dir:
                    header_dir_to_check = os.path.dirname(fin_path)
                    header_path_to_check = os.path.abspath(os.path.join(header_dir_to_check, f))
                    if os.path.exists(header_path_to_check):
                        header_dir = header_dir_to_check
                        header_filepath = header_path_to_check
                # If not found, look in include dirs one by one and first match wins
                if header_filepath is None:
                    for header_include_dir in header_include_dirs:
                        header_dir_to_check = os.path.join(output_directory, header_include_dir)
                        header_path_to_check = os.path.abspath(os.path.join(header_dir_to_check, f))
                        if os.path.exists(header_path_to_check):
                            header_dir = header_dir_to_check
                            header_filepath = header_path_to_check
                # If header file not found, keep as is
                if header_filepath is None:
                    return m.group(0)
                # Hipify header file first if needed
                if header_filepath not in HIPIFY_FINAL_RESULT:
                    preprocess_file_and_save_result(output_directory,
                                                    header_filepath,
                                                    all_files, header_include_dirs, stats, hip_clang_launch,
                                                    is_pytorch_extension, clean_ctx, show_progress)
                elif header_filepath in HIPIFY_FINAL_RESULT:
                    header_result = HIPIFY_FINAL_RESULT[header_filepath]
                    if header_result.current_state == CurrentState.INITIALIZED:
                        # get_hip_file_path needs a relative path to work correctly
                        header_rel_path = os.path.relpath(header_filepath, output_directory)
                        header_fout_path = os.path.abspath(os.path.join(output_directory,
                                                           get_hip_file_path(header_rel_path, is_pytorch_extension)))
                        header_result.hipified_path = header_fout_path
                        HIPIFY_FINAL_RESULT[header_filepath] = header_result
                        return templ.format(os.path.relpath(header_fout_path if header_fout_path is not None
                                                            else header_filepath, header_dir))
                hipified_header_filepath = HIPIFY_FINAL_RESULT[header_filepath].hipified_path
                return templ.format(os.path.relpath(hipified_header_filepath if hipified_header_filepath is not None
                                                    else header_filepath, header_dir))

            return m.group(0)
        return repl
    output_source = RE_QUOTE_HEADER.sub(mk_repl('#include "{0}"', True), output_source)
    output_source = RE_ANGLE_HEADER.sub(mk_repl('#include <{0}>', False), output_source)
    output_source = RE_THC_GENERIC_FILE.sub(mk_repl('#define THC_GENERIC_FILE "{0}"'), output_source)
    # CMakeLists.txt rewrites
    if filepath.endswith('CMakeLists.txt'):
        output_source = output_source.replace('CUDA', 'HIP')
        output_source = output_source.replace('THC', 'THH')
        output_source = RE_CU_SUFFIX.sub('.hip', output_source)

    # Perform Kernel Launch Replacements
    if not hip_clang_launch:
        output_source = processKernelLaunches(output_source, stats)

    # Replace std:: with non-std:: versions
    if (filepath.endswith((".cu", ".cuh"))) and "PowKernel" not in filepath:
        output_source = replace_math_functions(output_source)

    # Include header if device code is contained.
    output_source = hip_header_magic(output_source)

    # Replace the extern __shared__
    # NOTE: No longer needed after transition from hcc to hipclang.
    # output_source = replace_extern_shared(output_source)

    # Don't write out identical hipified files for extensions if dirpath has not changed
    if (
        is_pytorch_extension
        and orig_output_source == output_source
        and os.path.dirname(fin_path) == os.path.dirname(fout_path)
    ):
        hipify_result.hipified_path = fin_path
        hipify_result.status = "[skipped, no changes]"
        hipify_result.current_state = CurrentState.DONE
        return hipify_result

    # Add hipify breadcrumb for C-style files to avoid re-hipification
    if fin_path != fout_path and match_extensions(fin_path, (".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".hpp")):
        output_source = HIPIFY_C_BREADCRUMB + output_source

    do_write = True
    if os.path.exists(fout_path):
        with open(fout_path, encoding='utf-8') as fout_old:
            do_write = fout_old.read() != output_source
    if do_write:
        try:
            with clean_ctx.open(fout_path, 'w', encoding='utf-8') as fout:
                fout.write(output_source)
            hipify_result.hipified_path = fout_path
            hipify_result.status = "[ok]"
            hipify_result.current_state = CurrentState.DONE
            return hipify_result
        except PermissionError as e:
            print(f'{bcolors.WARNING}Failed to save {fout_path} with "{e.strerror}", leaving {fin_path} unchanged.{bcolors.ENDC}',
                  file=sys.stderr)
            hipify_result.hipified_path = fin_path
            hipify_result.status = "[skipped, no permissions]"
            hipify_result.current_state = CurrentState.DONE
            return hipify_result
    else:
        hipify_result.hipified_path = fout_path
        hipify_result.status = "[skipped, already hipified]"
        hipify_result.current_state = CurrentState.DONE
        return hipify_result

def file_specific_replacement(filepath, search_string, replace_string, strict=False):
    with openf(filepath, "r+") as f:
        contents = f.read()
        if strict:
            contents = re.sub(fr'\b({re.escape(search_string)})\b', lambda x: replace_string, contents)
        else:
            contents = contents.replace(search_string, replace_string)
        f.seek(0)
        f.write(contents)
        f.truncate()


def file_add_header(filepath, header):
    with openf(filepath, "r+") as f:
        contents = f.read()
        if header[0] != "<" and header[-1] != ">":
            header = f'"{header}"'
        contents = (f'#include {header} \n') + contents
        f.seek(0)
        f.write(contents)
        f.truncate()

def fix_static_global_kernels(in_txt):
    """Static global kernels in HIP result in a compilation error."""
    in_txt = in_txt.replace(" __global__ static", "__global__")
    return in_txt

RE_INCLUDE = re.compile(r"#include .*\n")


def extract_arguments(start, string):
    """ Return the list of arguments in the upcoming function parameter closure.
        Example:
        string (input): '(blocks, threads, 0, THCState_getCurrentStream(state))'
        arguments (output):
            '[{'start': 1, 'end': 7},
            {'start': 8, 'end': 16},
            {'start': 17, 'end': 19},
            {'start': 20, 'end': 53}]'
    """

    arguments = []
    closures = {
        "<": 0,
        "(": 0
    }
    current_position = start
    argument_start_pos = current_position + 1

    # Search for final parenthesis
    while current_position < len(string):
        if string[current_position] == "(":
            closures["("] += 1
        elif string[current_position] == ")":
            closures["("] -= 1
        elif string[current_position] == "<":
            closures["<"] += 1
        elif string[current_position] == ">" and string[current_position - 1] != "-" and closures["<"] > 0:
            closures["<"] -= 1

        # Finished all arguments
        if closures["("] == 0 and closures["<"] == 0:
            # Add final argument
            arguments.append({"start": argument_start_pos, "end": current_position})
            break

        # Finished current argument
        if closures["("] == 1 and closures["<"] == 0 and string[current_position] == ",":
            arguments.append({"start": argument_start_pos, "end": current_position})
            argument_start_pos = current_position + 1

        current_position += 1

    return arguments

def str2bool(v):
    """ArgumentParser doesn't support type=bool. Thus, this helper method will convert
    from possible string types to True / False."""
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')

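# Illustrative argparse usage (a minimal sketch; the flag name is hypothetical):
#
#   >>> parser = argparse.ArgumentParser()
#   >>> _ = parser.add_argument("--out-of-place-only", type=str2bool, default=False)
#   >>> parser.parse_args(["--out-of-place-only", "yes"]).out_of_place_only
#   True
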
def hipify(
    project_directory: str,
    show_detailed: bool = False,
    extensions: Iterable = (".cu", ".cuh", ".c", ".cc", ".cpp", ".h", ".in", ".hpp"),
    header_extensions: Iterable = (".cuh", ".h", ".hpp"),
    output_directory: str = "",
    header_include_dirs: Iterable = (),
    includes: Iterable = ('*',),
    extra_files: Iterable = (),
    out_of_place_only: bool = False,
    ignores: Iterable = (),
    show_progress: bool = True,
    hip_clang_launch: bool = False,
    is_pytorch_extension: bool = False,
    hipify_extra_files_only: bool = False,
    clean_ctx: Optional[GeneratedFileCleaner] = None
) -> HipifyFinalResult:
    if project_directory == "":
        project_directory = os.getcwd()

    # Verify the project directory exists.
    if not os.path.exists(project_directory):
        print("The project folder specified does not exist.")
        sys.exit(1)
    # If no output directory, provide a default one.
    if not output_directory:
        project_directory = project_directory.rstrip("/")
        output_directory = project_directory + "_amd"
    if project_directory != output_directory:
        includes = [include.replace(project_directory, output_directory) for include in includes]
        ignores = [ignore.replace(project_directory, output_directory) for ignore in ignores]

    # Copy from project directory to output directory if not done already.
    if not os.path.exists(output_directory):
        shutil.copytree(project_directory, output_directory)

    all_files = list(matched_files_iter(output_directory, includes=includes,
                                        ignores=ignores, extensions=extensions,
                                        out_of_place_only=out_of_place_only,
                                        is_pytorch_extension=is_pytorch_extension))
    all_files_set = set(all_files)
    for f in extra_files:
        if not os.path.isabs(f):
            f = os.path.join(output_directory, f)
        if f not in all_files_set:
            all_files.append(f)

    # List all files in header_include_paths to ensure they are hipified
    from pathlib import Path
    for header_include_dir in header_include_dirs:
        if os.path.isabs(header_include_dir):
            header_include_dir_path = Path(header_include_dir)
        else:
            header_include_dir_path = Path(os.path.join(output_directory, header_include_dir))
        for path in header_include_dir_path.rglob('*'):
            if (
                path.is_file()
                and _fnmatch(str(path), includes)
                and (not _fnmatch(str(path), ignores))
                and match_extensions(path.name, header_extensions)
            ):
                all_files.append(str(path))

    if clean_ctx is None:
        clean_ctx = GeneratedFileCleaner(keep_intermediates=True)

    # Preprocessing statistics.
    stats: Dict[str, List] = {"unsupported_calls": [], "kernel_launches": []}

    for filepath in (all_files if not hipify_extra_files_only else extra_files):
        preprocess_file_and_save_result(output_directory, filepath, all_files, header_include_dirs,
                                        stats, hip_clang_launch, is_pytorch_extension, clean_ctx, show_progress)

    print(bcolors.OKGREEN + "Successfully preprocessed all matching files." + bcolors.ENDC, file=sys.stderr)

    # Show detailed summary
    if show_detailed:
        compute_stats(stats)

    return HIPIFY_FINAL_RESULT
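
# Illustrative invocation (a minimal sketch, not the exact call made by
# torch.utils.cpp_extension; the paths and file names are hypothetical):
#
#   result = hipify(
#       project_directory="/path/to/my_extension",
#       output_directory="/path/to/my_extension",
#       header_include_dirs=["include"],
#       includes=["*"],
#       extra_files=["csrc/my_kernels.cu"],
#       is_pytorch_extension=True,
#       hipify_extra_files_only=True,
#       show_progress=False,
#   )
#   # result maps each absolute source path to a HipifyResult carrying the
#   # hipified output path and a status string such as "[ok]" or "[skipped, ...]".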