benchmark_utils.py
# This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
# Copyright 2020 The HuggingFace Team and the AllenNLP authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  15. """
  16. Utilities for working with the local dataset cache.
  17. """

import copy
import csv
import linecache
import os
import platform
import sys
import warnings
from abc import ABC, abstractmethod
from collections import defaultdict, namedtuple
from datetime import datetime
from multiprocessing import Pipe, Process, Queue
from multiprocessing.connection import Connection
from typing import Callable, Iterable, List, NamedTuple, Optional, Union

from .. import AutoConfig, PretrainedConfig
from .. import __version__ as version
from ..utils import is_psutil_available, is_py3nvml_available, is_tf_available, is_torch_available, logging
from .benchmark_args_utils import BenchmarkArguments


if is_torch_available():
    from torch.cuda import empty_cache as torch_empty_cache

if is_tf_available():
    from tensorflow.python.eager import context as tf_context

if is_psutil_available():
    import psutil

if is_py3nvml_available():
    import py3nvml.py3nvml as nvml

if platform.system() == "Windows":
    from signal import CTRL_C_EVENT as SIGKILL
else:
    from signal import SIGKILL

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

_is_memory_tracing_enabled = False

BenchmarkOutput = namedtuple(
    "BenchmarkOutput",
    [
        "time_inference_result",
        "memory_inference_result",
        "time_train_result",
        "memory_train_result",
        "inference_summary",
        "train_summary",
    ],
)

def separate_process_wrapper_fn(func: Callable[[], None], do_multi_processing: bool) -> Callable[[], None]:
    """
    This function wraps another function into its own separate process. In order to ensure accurate memory
    measurements it is important that the function is executed in a separate process.

    Args:
        - `func`: (`callable`): function() -> ... generic function which will be executed in its own separate process
        - `do_multi_processing`: (`bool`) Whether to run the function in a separate process or not
    """

    def multi_process_func(*args, **kwargs):
        # run function in an individual
        # process to get correct memory
        def wrapper_func(queue: Queue, *args):
            try:
                result = func(*args)
            except Exception as e:
                logger.error(e)
                print(e)
                result = "N/A"
            queue.put(result)

        queue = Queue()
        p = Process(target=wrapper_func, args=[queue] + list(args))
        p.start()
        result = queue.get()
        p.join()
        return result

    if do_multi_processing:
        logger.info(f"Function {func} is executed in its own process...")
        return multi_process_func
    else:
        return func
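
# Illustrative usage sketch (kept as a comment so importing this module stays side-effect free):
# `measure_fn` is a hypothetical zero-argument measurement function supplied by the caller.
#
#     wrapped = separate_process_wrapper_fn(measure_fn, do_multi_processing=True)
#     result = wrapped()  # spawns a Process, runs measure_fn there, returns its result (or "N/A" on error)
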

def is_memory_tracing_enabled():
    global _is_memory_tracing_enabled
    return _is_memory_tracing_enabled

class Frame(NamedTuple):
    """
    `Frame` is a NamedTuple used to gather the current frame state. `Frame` has the following fields:

        - 'filename' (string): Name of the file currently executed
        - 'module' (string): Name of the module currently executed
        - 'line_number' (int): Number of the line currently executed
        - 'event' (string): Event that triggered the tracing (default will be "line")
        - 'line_text' (string): Text of the line in the python script
    """

    filename: str
    module: str
    line_number: int
    event: str
    line_text: str

class UsedMemoryState(NamedTuple):
    """
    `UsedMemoryState` are named tuples with the following fields:

        - 'frame': a `Frame` namedtuple (see above) storing information on the current tracing frame (current file,
          location in current file)
        - 'cpu_memory': CPU RSS memory state *before* executing the line
        - 'gpu_memory': GPU used memory *before* executing the line (sum for all GPUs or for only `gpus_to_trace` if
          provided)
    """

    frame: Frame
    cpu_memory: int
    gpu_memory: int

class Memory(NamedTuple):
    """
    `Memory` is a NamedTuple with a single field `bytes`; calling `__repr__` returns a human-readable string with the
    number of megabytes:

        - `bytes` (integer): number of bytes
    """

    bytes: int

    def __repr__(self) -> str:
        return str(bytes_to_mega_bytes(self.bytes))

class MemoryState(NamedTuple):
    """
    `MemoryState` are namedtuples listing frame + CPU/GPU memory with the following fields:

        - `frame` (`Frame`): the current frame (see above)
        - `cpu`: CPU memory consumed during the current frame as a `Memory` named tuple
        - `gpu`: GPU memory consumed during the current frame as a `Memory` named tuple
        - `cpu_gpu`: CPU + GPU memory consumed during the current frame as a `Memory` named tuple
    """

    frame: Frame
    cpu: Memory
    gpu: Memory
    cpu_gpu: Memory

class MemorySummary(NamedTuple):
    """
    `MemorySummary` namedtuple with the following fields:

        - `sequential`: a list of `MemoryState` namedtuples (see above) computed from the provided `memory_trace` by
          subtracting the memory after executing each line from the memory before executing said line.
        - `cumulative`: a list of `MemoryState` namedtuples (see above) with the cumulative increase in memory for
          each line, obtained by summing the repeated memory increases of a line if it is executed several times. The
          list is sorted from the frame with the largest memory consumption to the frame with the smallest (can be
          negative if memory is released).
        - `current`: a list of `MemoryState` namedtuples (see above) with the current (absolute) memory at each line,
          sorted from largest to smallest.
        - `total`: total memory increase during the full tracing as a `Memory` named tuple (see above). Lines with
          memory release (negative consumption) are ignored if `ignore_released_memory` is `True` (default).
    """

    sequential: List[MemoryState]
    cumulative: List[MemoryState]
    current: List[MemoryState]
    total: Memory


MemoryTrace = List[UsedMemoryState]
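
# A `MemoryTrace` is simply the list of per-line snapshots collected by `start_memory_tracing` below.
# One entry looks like this (values are illustrative only):
#
#     UsedMemoryState(
#         frame=Frame(
#             filename="modeling_gpt2.py",
#             module="transformers.models.gpt2.modeling_gpt2",
#             line_number=123,
#             event="line",
#             line_text="hidden_states = self.mlp(hidden_states)",
#         ),
#         cpu_memory=1_200_000_000,  # process RSS in bytes *before* executing the line
#         gpu_memory=800_000_000,    # summed GPU used memory in bytes *before* executing the line
#     )
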

def measure_peak_memory_cpu(function: Callable[[], None], interval=0.5, device_idx=None) -> int:
    """
    Measures the peak CPU memory consumption of a given `function`, running the function for at least `interval`
    seconds and at most 20 * `interval` seconds. This function is heavily inspired by `memory_usage` of the package
    `memory_profiler`:
    https://github.com/pythonprofilers/memory_profiler/blob/895c4ac7a08020d66ae001e24067da6dcea42451/memory_profiler.py#L239

    Args:
        - `function`: (`callable`): function() -> ... function without any arguments for which to measure the
          peak memory
        - `interval`: (`float`, `optional`, defaults to `0.5`) interval in seconds at which to measure the memory usage
        - `device_idx`: (`int`, `optional`, defaults to `None`) device id for which to measure gpu usage

    Returns:
        - `max_memory`: (`int`) consumed memory peak in Bytes
    """

    def get_cpu_memory(process_id: int) -> int:
        """
        Measures the current cpu memory usage of a given `process_id`

        Args:
            - `process_id`: (`int`) process_id for which to measure memory

        Returns:
            - `memory`: (`int`) consumed memory in Bytes
        """
        process = psutil.Process(process_id)
        try:
            meminfo_attr = "memory_info" if hasattr(process, "memory_info") else "get_memory_info"
            memory = getattr(process, meminfo_attr)()[0]
        except psutil.AccessDenied:
            raise ValueError("Error with Psutil.")
        return memory

    if not is_psutil_available():
        logger.warning(
            "Psutil not installed, we won't log CPU memory usage. "
            "Install Psutil (pip install psutil) to use CPU memory tracing."
        )
        max_memory = "N/A"
    else:

        class MemoryMeasureProcess(Process):
            """
            `MemoryMeasureProcess` inherits from `Process` and overwrites its `run()` method. Used to measure the
            memory usage of a process.
            """

            def __init__(self, process_id: int, child_connection: Connection, interval: float):
                super().__init__()
                self.process_id = process_id
                self.interval = interval
                self.connection = child_connection
                self.num_measurements = 1
                self.mem_usage = get_cpu_memory(self.process_id)

            def run(self):
                self.connection.send(0)
                stop = False
                while True:
                    self.mem_usage = max(self.mem_usage, get_cpu_memory(self.process_id))
                    self.num_measurements += 1

                    if stop:
                        break

                    stop = self.connection.poll(self.interval)

                # send results to parent pipe
                self.connection.send(self.mem_usage)
                self.connection.send(self.num_measurements)

        while True:
            # create child, parent connection
            child_connection, parent_connection = Pipe()

            # instantiate process
            mem_process = MemoryMeasureProcess(os.getpid(), child_connection, interval)
            mem_process.start()

            # wait until we get memory
            parent_connection.recv()

            try:
                # execute function
                function()

                # start parent connection
                parent_connection.send(0)

                # receive memory and num measurements
                max_memory = parent_connection.recv()
                num_measurements = parent_connection.recv()
            except Exception:
                # kill process in a clean way
                parent = psutil.Process(os.getpid())
                for child in parent.children(recursive=True):
                    os.kill(child.pid, SIGKILL)
                mem_process.join(0)
                raise RuntimeError("Process killed. Error in Process")

            # run process at least 20 * interval or until it finishes
            mem_process.join(20 * interval)

            if (num_measurements > 4) or (interval < 1e-6):
                break

            # reduce interval
            interval /= 10

    return max_memory
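
# Illustrative usage sketch (kept as a comment; `allocate` is a hypothetical workload):
#
#     def allocate():
#         data = [0] * 10_000_000  # grow the current process' RSS
#
#     peak_bytes = measure_peak_memory_cpu(allocate, interval=0.5)
#     print(bytes_to_mega_bytes(peak_bytes))  # peak RSS in MB (note: "N/A" is returned if psutil is missing)
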

def start_memory_tracing(
    modules_to_trace: Optional[Union[str, Iterable[str]]] = None,
    modules_not_to_trace: Optional[Union[str, Iterable[str]]] = None,
    events_to_trace: str = "line",
    gpus_to_trace: Optional[List[int]] = None,
) -> MemoryTrace:
    """
    Set up line-by-line tracing to record the RSS memory (RAM) at each line of a module or sub-module. See
    `./benchmark.py` for usage examples. Current memory consumption is returned using psutil and in particular is the
    RSS memory "Resident Set Size" (the non-swapped physical memory the process is using). See
    https://psutil.readthedocs.io/en/latest/#psutil.Process.memory_info

    Args:
        - `modules_to_trace`: (None, string, list/tuple of string) if None, all events are recorded; if string or list
          of strings: only events from the listed module/sub-module will be recorded (e.g. 'fairseq' or
          'transformers.models.gpt2.modeling_gpt2')
        - `modules_not_to_trace`: (None, string, list/tuple of string) if None, no module is avoided; if string or
          list of strings: events from the listed module/sub-module will not be recorded (e.g. 'torch')
        - `events_to_trace`: string or list of strings of events to be recorded (see the official python doc for
          `sys.settrace` for the list of events), defaults to "line"
        - `gpus_to_trace`: (optional list, default None) list of GPUs to trace. Defaults to tracing all GPUs

    Return:
        - `memory_trace` is a list of `UsedMemoryState` for each event (default: each line of the traced script).
          `UsedMemoryState` are named tuples with the following fields:

            - 'frame': a `Frame` namedtuple (see below) storing information on the current tracing frame (current
              file, location in current file)
            - 'cpu_memory': CPU RSS memory state *before* executing the line
            - 'gpu_memory': GPU used memory *before* executing the line (sum for all GPUs or for only
              `gpus_to_trace` if provided)

        `Frame` is a namedtuple used by `UsedMemoryState` to list the current frame state. `Frame` has the following
        fields:

            - 'filename' (string): Name of the file currently executed
            - 'module' (string): Name of the module currently executed
            - 'line_number' (int): Number of the line currently executed
            - 'event' (string): Event that triggered the tracing (default will be "line")
            - 'line_text' (string): Text of the line in the python script
    """
    if is_psutil_available():
        process = psutil.Process(os.getpid())
    else:
        logger.warning(
            "Psutil not installed, we won't log CPU memory usage. "
            "Install psutil (pip install psutil) to use CPU memory tracing."
        )
        process = None

    if is_py3nvml_available():
        try:
            nvml.nvmlInit()
            devices = list(range(nvml.nvmlDeviceGetCount())) if gpus_to_trace is None else gpus_to_trace
            nvml.nvmlShutdown()
        except (OSError, nvml.NVMLError):
            logger.warning("Error while initializing communication with GPU. We won't perform GPU memory tracing.")
            log_gpu = False
        else:
            log_gpu = is_torch_available() or is_tf_available()
    else:
        logger.warning(
            "py3nvml not installed, we won't log GPU memory usage. "
            "Install py3nvml (pip install py3nvml) to use GPU memory tracing."
        )
        log_gpu = False

    memory_trace = []

    def traceit(frame, event, args):
        """
        Tracing method executed before running each line in a module or sub-module. Records the memory allocated in a
        list with debugging information.
        """
        global _is_memory_tracing_enabled

        if not _is_memory_tracing_enabled:
            return traceit

        # Filter events
        if events_to_trace is not None:
            if isinstance(events_to_trace, str) and event != events_to_trace:
                return traceit
            elif isinstance(events_to_trace, (list, tuple)) and event not in events_to_trace:
                return traceit

        if "__name__" not in frame.f_globals:
            return traceit

        # Filter modules
        name = frame.f_globals["__name__"]
        if not isinstance(name, str):
            return traceit
        else:
            # Filter whitelist of modules to trace
            if modules_to_trace is not None:
                if isinstance(modules_to_trace, str) and modules_to_trace not in name:
                    return traceit
                elif isinstance(modules_to_trace, (list, tuple)) and all(m not in name for m in modules_to_trace):
                    return traceit

            # Filter blacklist of modules not to trace
            if modules_not_to_trace is not None:
                if isinstance(modules_not_to_trace, str) and modules_not_to_trace in name:
                    return traceit
                elif isinstance(modules_not_to_trace, (list, tuple)) and any(m in name for m in modules_not_to_trace):
                    return traceit

        # Record current tracing state (file, location in file...)
        lineno = frame.f_lineno
        filename = frame.f_globals["__file__"]
        if filename.endswith(".pyc") or filename.endswith(".pyo"):
            filename = filename[:-1]
        line = linecache.getline(filename, lineno).rstrip()
        traced_state = Frame(filename, name, lineno, event, line)

        # Record current memory state (rss memory) and compute difference with previous memory state
        cpu_mem = 0
        if process is not None:
            mem = process.memory_info()
            cpu_mem = mem.rss

        gpu_mem = 0
        if log_gpu:
            # Clear GPU caches
            if is_torch_available():
                torch_empty_cache()
            if is_tf_available():
                tf_context.context()._clear_caches()  # See https://github.com/tensorflow/tensorflow/issues/20218#issuecomment-416771802

            # Sum used memory for all GPUs
            nvml.nvmlInit()

            for i in devices:
                handle = nvml.nvmlDeviceGetHandleByIndex(i)
                meminfo = nvml.nvmlDeviceGetMemoryInfo(handle)
                gpu_mem += meminfo.used

            nvml.nvmlShutdown()

        mem_state = UsedMemoryState(traced_state, cpu_mem, gpu_mem)
        memory_trace.append(mem_state)

        return traceit

    sys.settrace(traceit)

    global _is_memory_tracing_enabled
    _is_memory_tracing_enabled = True

    return memory_trace
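
# Illustrative usage sketch (kept as a comment): start line-by-line tracing restricted to
# `transformers` modules. `start_memory_tracing` installs a `sys.settrace` hook, so everything
# executed afterwards in this process is traced until `stop_memory_tracing` is called.
# `run_forward_pass` is a hypothetical user function.
#
#     trace = start_memory_tracing("transformers")
#     run_forward_pass()
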

def stop_memory_tracing(
    memory_trace: Optional[MemoryTrace] = None, ignore_released_memory: bool = True
) -> Optional[MemorySummary]:
    """
    Stop memory tracing cleanly and return a summary of the memory trace if a trace is given.

    Args:
        `memory_trace` (optional output of start_memory_tracing, default: None):
            memory trace to convert into a summary
        `ignore_released_memory` (boolean, default: True):
            if True we only sum memory increases to compute the total memory

    Return:
        - None if `memory_trace` is None
        - `MemorySummary` namedtuple otherwise with the fields:

            - `sequential`: a list of `MemoryState` namedtuples (see below) computed from the provided `memory_trace`
              by subtracting the memory after executing each line from the memory before executing said line.
            - `cumulative`: a list of `MemoryState` namedtuples (see below) with the cumulative increase in memory for
              each line, obtained by summing the repeated memory increases of a line if it is executed several times.
              The list is sorted from the frame with the largest memory consumption to the frame with the smallest
              (can be negative if memory is released).
            - `current`: a list of `MemoryState` namedtuples (see below) with the current (absolute) memory at each
              line, sorted from largest to smallest.
            - `total`: total memory increase during the full tracing as a `Memory` named tuple (see below). Lines with
              memory release (negative consumption) are ignored if `ignore_released_memory` is `True` (default).

        `Memory` named tuples have the field:

            - `bytes` (integer): number of bytes, printed as megabytes by `__repr__`

        `Frame` are namedtuples used to list the current frame state and have the following fields:

            - 'filename' (string): Name of the file currently executed
            - 'module' (string): Name of the module currently executed
            - 'line_number' (int): Number of the line currently executed
            - 'event' (string): Event that triggered the tracing (default will be "line")
            - 'line_text' (string): Text of the line in the python script

        `MemoryState` are namedtuples listing frame + CPU/GPU memory with the following fields:

            - `frame` (`Frame`): the current frame (see above)
            - `cpu`: CPU memory consumed during the current frame as a `Memory` named tuple
            - `gpu`: GPU memory consumed during the current frame as a `Memory` named tuple
            - `cpu_gpu`: CPU + GPU memory consumed during the current frame as a `Memory` named tuple
    """
    global _is_memory_tracing_enabled
    _is_memory_tracing_enabled = False

    if memory_trace is not None and len(memory_trace) > 1:
        memory_diff_trace = []
        memory_curr_trace = []

        cumulative_memory_dict = defaultdict(lambda: [0, 0, 0])

        for (
            (frame, cpu_mem, gpu_mem),
            (next_frame, next_cpu_mem, next_gpu_mem),
        ) in zip(memory_trace[:-1], memory_trace[1:]):
            cpu_mem_inc = next_cpu_mem - cpu_mem
            gpu_mem_inc = next_gpu_mem - gpu_mem
            cpu_gpu_mem_inc = cpu_mem_inc + gpu_mem_inc

            memory_diff_trace.append(
                MemoryState(
                    frame=frame,
                    cpu=Memory(cpu_mem_inc),
                    gpu=Memory(gpu_mem_inc),
                    cpu_gpu=Memory(cpu_gpu_mem_inc),
                )
            )

            memory_curr_trace.append(
                MemoryState(
                    frame=frame,
                    cpu=Memory(next_cpu_mem),
                    gpu=Memory(next_gpu_mem),
                    cpu_gpu=Memory(next_gpu_mem + next_cpu_mem),
                )
            )

            cumulative_memory_dict[frame][0] += cpu_mem_inc
            cumulative_memory_dict[frame][1] += gpu_mem_inc
            cumulative_memory_dict[frame][2] += cpu_gpu_mem_inc

        cumulative_memory = sorted(
            cumulative_memory_dict.items(), key=lambda x: x[1][2], reverse=True
        )  # order by the total CPU + GPU memory increase
        cumulative_memory = [
            MemoryState(
                frame=frame,
                cpu=Memory(cpu_mem_inc),
                gpu=Memory(gpu_mem_inc),
                cpu_gpu=Memory(cpu_gpu_mem_inc),
            )
            for frame, (cpu_mem_inc, gpu_mem_inc, cpu_gpu_mem_inc) in cumulative_memory
        ]

        memory_curr_trace = sorted(memory_curr_trace, key=lambda x: x.cpu_gpu.bytes, reverse=True)

        if ignore_released_memory:
            total_memory = sum(max(0, step_trace.cpu_gpu.bytes) for step_trace in memory_diff_trace)
        else:
            total_memory = sum(step_trace.cpu_gpu.bytes for step_trace in memory_diff_trace)

        total_memory = Memory(total_memory)

        return MemorySummary(
            sequential=memory_diff_trace,
            cumulative=cumulative_memory,
            current=memory_curr_trace,
            total=total_memory,
        )

    return None
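
# Illustrative usage sketch (kept as a comment), continuing the example above: turn the raw
# trace into a `MemorySummary` and inspect it.
#
#     summary = stop_memory_tracing(trace)
#     print(summary.total)                  # total memory increase (a Memory, printed in MB)
#     for state in summary.cumulative[:3]:  # the three most memory-hungry lines
#         print(state.frame.filename, state.frame.line_number, state.cpu_gpu)
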

def bytes_to_mega_bytes(memory_amount: int) -> int:
    """Utility to convert a number of bytes (int) into a number of mega bytes (int)"""
    return memory_amount >> 20
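
# Worked example: the conversion is an integer right shift, so it truncates toward zero:
# bytes_to_mega_bytes(2_621_440) == 2_621_440 >> 20 == 2, i.e. 2.5 MB is reported as 2 MB.
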

class Benchmark(ABC):
    """
    Benchmarks is a simple but feature-complete benchmarking script to compare memory and time performance of models
    in Transformers.
    """

    args: BenchmarkArguments
    configs: PretrainedConfig
    framework: str

    def __init__(self, args: BenchmarkArguments = None, configs: PretrainedConfig = None):
        self.args = args
        if configs is None:
            self.config_dict = {
                model_name: AutoConfig.from_pretrained(model_name) for model_name in self.args.model_names
            }
        else:
            self.config_dict = dict(zip(self.args.model_names, configs))

        warnings.warn(
            f"The class {self.__class__} is deprecated. Hugging Face Benchmarking utils"
            " are deprecated in general and it is advised to use external Benchmarking libraries"
            " to benchmark Transformer models.",
            FutureWarning,
        )

        if self.args.memory and os.getenv("TRANSFORMERS_USE_MULTIPROCESSING") == "0":
            logger.warning(
                "Memory consumption will not be measured accurately if `args.multi_process` is set to `False`. The"
                " flag 'TRANSFORMERS_USE_MULTIPROCESSING' should only be disabled for debugging / testing."
            )

        self._print_fn = None
        self._framework_version = None
        self._environment_info = None

    @property
    def print_fn(self):
        if self._print_fn is None:
            if self.args.log_print:

                def print_and_log(*args):
                    with open(self.args.log_filename, "a") as log_file:
                        log_file.write("".join(args) + "\n")
                    print(*args)

                self._print_fn = print_and_log
            else:
                self._print_fn = print
        return self._print_fn

    @property
    @abstractmethod
    def framework_version(self):
        pass

    @abstractmethod
    def _inference_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float:
        pass

    @abstractmethod
    def _train_speed(self, model_name: str, batch_size: int, sequence_length: int) -> float:
        pass

    @abstractmethod
    def _inference_memory(
        self, model_name: str, batch_size: int, sequence_length: int
    ) -> [Memory, Optional[MemorySummary]]:
        pass

    @abstractmethod
    def _train_memory(
        self, model_name: str, batch_size: int, sequence_length: int
    ) -> [Memory, Optional[MemorySummary]]:
        pass

    def inference_speed(self, *args, **kwargs) -> float:
        return separate_process_wrapper_fn(self._inference_speed, self.args.do_multi_processing)(*args, **kwargs)

    def train_speed(self, *args, **kwargs) -> float:
        return separate_process_wrapper_fn(self._train_speed, self.args.do_multi_processing)(*args, **kwargs)

    def inference_memory(self, *args, **kwargs) -> [Memory, Optional[MemorySummary]]:
        return separate_process_wrapper_fn(self._inference_memory, self.args.do_multi_processing)(*args, **kwargs)

    def train_memory(self, *args, **kwargs) -> [Memory, Optional[MemorySummary]]:
        return separate_process_wrapper_fn(self._train_memory, self.args.do_multi_processing)(*args, **kwargs)

    def run(self):
        result_dict = {model_name: {} for model_name in self.args.model_names}
        inference_result_time = copy.deepcopy(result_dict)
        inference_result_memory = copy.deepcopy(result_dict)
        train_result_time = copy.deepcopy(result_dict)
        train_result_memory = copy.deepcopy(result_dict)

        for c, model_name in enumerate(self.args.model_names):
            self.print_fn(f"{c + 1} / {len(self.args.model_names)}")

            model_dict = {
                "bs": self.args.batch_sizes,
                "ss": self.args.sequence_lengths,
                "result": {i: {} for i in self.args.batch_sizes},
            }
            inference_result_time[model_name] = copy.deepcopy(model_dict)
            inference_result_memory[model_name] = copy.deepcopy(model_dict)
            train_result_time[model_name] = copy.deepcopy(model_dict)
            train_result_memory[model_name] = copy.deepcopy(model_dict)

            inference_summary = train_summary = None

            for batch_size in self.args.batch_sizes:
                for sequence_length in self.args.sequence_lengths:
                    if self.args.inference:
                        if self.args.memory:
                            memory, inference_summary = self.inference_memory(model_name, batch_size, sequence_length)
                            inference_result_memory[model_name]["result"][batch_size][sequence_length] = memory
                        if self.args.speed:
                            time = self.inference_speed(model_name, batch_size, sequence_length)
                            inference_result_time[model_name]["result"][batch_size][sequence_length] = time

                    if self.args.training:
                        if self.args.memory:
                            memory, train_summary = self.train_memory(model_name, batch_size, sequence_length)
                            train_result_memory[model_name]["result"][batch_size][sequence_length] = memory
                        if self.args.speed:
                            time = self.train_speed(model_name, batch_size, sequence_length)
                            train_result_time[model_name]["result"][batch_size][sequence_length] = time

        if self.args.inference:
            if self.args.speed:
                self.print_fn("\n" + 20 * "=" + ("INFERENCE - SPEED - RESULT").center(40) + 20 * "=")
                self.print_results(inference_result_time, type_label="Time in s")
                self.save_to_csv(inference_result_time, self.args.inference_time_csv_file)
                if self.args.is_tpu:
                    self.print_fn(
                        "TPU was used for inference. Note that the time after compilation stabilized (after ~10"
                        " inferences model.forward(..) calls) was measured."
                    )

            if self.args.memory:
                self.print_fn("\n" + 20 * "=" + ("INFERENCE - MEMORY - RESULT").center(40) + 20 * "=")
                self.print_results(inference_result_memory, type_label="Memory in MB")
                self.save_to_csv(inference_result_memory, self.args.inference_memory_csv_file)

            if self.args.trace_memory_line_by_line:
                self.print_fn("\n" + 20 * "=" + ("INFERENCE - MEMORY - LINE BY LINE - SUMMARY").center(40) + 20 * "=")
                self.print_memory_trace_statistics(inference_summary)

        if self.args.training:
            if self.args.speed:
                self.print_fn("\n" + 20 * "=" + ("TRAIN - SPEED - RESULTS").center(40) + 20 * "=")
                self.print_results(train_result_time, "Time in s")
                self.save_to_csv(train_result_time, self.args.train_time_csv_file)
                if self.args.is_tpu:
                    self.print_fn(
                        "TPU was used for training. Note that the time after compilation stabilized (after ~10 train"
                        " loss=model.forward(...) + loss.backward() calls) was measured."
                    )

            if self.args.memory:
                self.print_fn("\n" + 20 * "=" + ("TRAIN - MEMORY - RESULTS").center(40) + 20 * "=")
                self.print_results(train_result_memory, type_label="Memory in MB")
                self.save_to_csv(train_result_memory, self.args.train_memory_csv_file)

            if self.args.trace_memory_line_by_line:
                self.print_fn("\n" + 20 * "=" + ("TRAIN - MEMORY - LINE BY LINE - SUMMARY").center(40) + 20 * "=")
                self.print_memory_trace_statistics(train_summary)

        if self.args.env_print:
            self.print_fn("\n" + 20 * "=" + ("ENVIRONMENT INFORMATION").center(40) + 20 * "=")
            self.print_fn("\n".join([f"- {prop}: {val}" for prop, val in self.environment_info.items()]) + "\n")

        if self.args.save_to_csv:
            with open(self.args.env_info_csv_file, mode="w", newline="") as csv_file:
                writer = csv.writer(csv_file)
                for key, value in self.environment_info.items():
                    writer.writerow([key, value])

        return BenchmarkOutput(
            inference_result_time,
            inference_result_memory,
            train_result_time,
            train_result_memory,
            inference_summary,
            train_summary,
        )

    @property
    def environment_info(self):
        if self._environment_info is None:
            info = {}
            info["transformers_version"] = version
            info["framework"] = self.framework
            if self.framework == "PyTorch":
                info["use_torchscript"] = self.args.torchscript
            if self.framework == "TensorFlow":
                info["eager_mode"] = self.args.eager_mode
                info["use_xla"] = self.args.use_xla
            info["framework_version"] = self.framework_version
            info["python_version"] = platform.python_version()
            info["system"] = platform.system()
            info["cpu"] = platform.processor()
            info["architecture"] = platform.architecture()[0]
            info["date"] = datetime.date(datetime.now())
            info["time"] = datetime.time(datetime.now())
            info["fp16"] = self.args.fp16
            info["use_multiprocessing"] = self.args.do_multi_processing
            info["only_pretrain_model"] = self.args.only_pretrain_model

            if is_psutil_available():
                info["cpu_ram_mb"] = bytes_to_mega_bytes(psutil.virtual_memory().total)
            else:
                logger.warning(
                    "Psutil not installed, we won't log available CPU memory. "
                    "Install psutil (pip install psutil) to log available CPU memory."
                )
                info["cpu_ram_mb"] = "N/A"

            info["use_gpu"] = self.args.is_gpu
            if self.args.is_gpu:
                info["num_gpus"] = 1  # TODO(PVP) Currently only single GPU is supported
                if is_py3nvml_available():
                    nvml.nvmlInit()
                    handle = nvml.nvmlDeviceGetHandleByIndex(self.args.device_idx)
                    info["gpu"] = nvml.nvmlDeviceGetName(handle)
                    info["gpu_ram_mb"] = bytes_to_mega_bytes(nvml.nvmlDeviceGetMemoryInfo(handle).total)
                    info["gpu_power_watts"] = nvml.nvmlDeviceGetPowerManagementLimit(handle) / 1000
                    info["gpu_performance_state"] = nvml.nvmlDeviceGetPerformanceState(handle)
                    nvml.nvmlShutdown()
                else:
                    logger.warning(
                        "py3nvml not installed, we won't log GPU memory usage. "
                        "Install py3nvml (pip install py3nvml) to log information about GPU."
                    )
                    info["gpu"] = "N/A"
                    info["gpu_ram_mb"] = "N/A"
                    info["gpu_power_watts"] = "N/A"
                    info["gpu_performance_state"] = "N/A"

            info["use_tpu"] = self.args.is_tpu
            # TODO(PVP): See if we can add more information about TPU
            # see: https://github.com/pytorch/xla/issues/2180

            self._environment_info = info

        return self._environment_info

    def print_results(self, result_dict, type_label):
        self.print_fn(80 * "-")
        self.print_fn(
            "Model Name".center(30) + "Batch Size".center(15) + "Seq Length".center(15) + type_label.center(15)
        )
        self.print_fn(80 * "-")
        for model_name in self.args.model_names:
            for batch_size in result_dict[model_name]["bs"]:
                for sequence_length in result_dict[model_name]["ss"]:
                    result = result_dict[model_name]["result"][batch_size][sequence_length]
                    if isinstance(result, float):
                        result = round(1000 * result) / 1000
                        result = "< 0.001" if result == 0.0 else str(result)
                    else:
                        result = str(result)
                    self.print_fn(
                        model_name[:30].center(30) + str(batch_size).center(15),
                        str(sequence_length).center(15),
                        result.center(15),
                    )
        self.print_fn(80 * "-")

    def print_memory_trace_statistics(self, summary: MemorySummary):
        self.print_fn(
            "\nLine by line memory consumption:\n"
            + "\n".join(
                f"{state.frame.filename}:{state.frame.line_number}: mem {state.cpu_gpu}: {state.frame.line_text}"
                for state in summary.sequential
            )
        )
        self.print_fn(
            "\nLines with top memory consumption:\n"
            + "\n".join(
                f"=> {state.frame.filename}:{state.frame.line_number}: mem {state.cpu_gpu}: {state.frame.line_text}"
                for state in summary.cumulative[:6]
            )
        )
        self.print_fn(
            "\nLines with lowest memory consumption:\n"
            + "\n".join(
                f"=> {state.frame.filename}:{state.frame.line_number}: mem {state.cpu_gpu}: {state.frame.line_text}"
                for state in summary.cumulative[-6:]
            )
        )
        self.print_fn(f"\nTotal memory increase: {summary.total}")

    def save_to_csv(self, result_dict, filename):
        if not self.args.save_to_csv:
            return
        self.print_fn("Saving results to csv.")
        with open(filename, mode="w") as csv_file:
            if len(self.args.model_names) <= 0:
                raise ValueError(f"At least 1 model should be defined, but got {self.args.model_names}")

            fieldnames = ["model", "batch_size", "sequence_length"]
            writer = csv.DictWriter(csv_file, fieldnames=fieldnames + ["result"])
            writer.writeheader()

            for model_name in self.args.model_names:
                result_dict_model = result_dict[model_name]["result"]
                for bs in result_dict_model:
                    for ss in result_dict_model[bs]:
                        result_model = result_dict_model[bs][ss]
                        writer.writerow(
                            {
                                "model": model_name,
                                "batch_size": bs,
                                "sequence_length": ss,
                                "result": ("{}" if not isinstance(result_model, float) else "{:.4f}").format(
                                    result_model
                                ),
                            }
                        )
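
# Illustrative usage sketch (kept as a comment): driving a concrete subclass. Framework-specific
# subclasses implement the abstract `_inference_speed` / `_inference_memory` / `_train_speed` /
# `_train_memory` methods and `framework_version`; `run()` then loops over every
# (model, batch_size, sequence_length) combination and returns a `BenchmarkOutput`.
# `SomeConcreteBenchmark` is a hypothetical subclass used only for illustration; argument names
# follow `BenchmarkArguments` from `benchmark_args_utils`.
#
#     args = BenchmarkArguments(models=["gpt2"], batch_sizes=[1], sequence_lengths=[128])
#     benchmark = SomeConcreteBenchmark(args)
#     results = benchmark.run()
#     print(results.time_inference_result["gpt2"]["result"][1][128])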