debug_utils.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346
  1. # Copyright 2020 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import collections
  15. from .utils import ExplicitEnum, is_torch_available, logging
  16. if is_torch_available():
  17. import torch
# Module-level logger, named after this module and created through the package's own `logging` helper.
logger = logging.get_logger(__name__)
  19. class DebugUnderflowOverflow:
  20. """
  21. This debug class helps detect and understand where the model starts getting very large or very small, and more
  22. importantly `nan` or `inf` weight and activation elements.
  23. There are 2 working modes:
  24. 1. Underflow/overflow detection (default)
  25. 2. Specific batch absolute min/max tracing without detection
  26. Mode 1: Underflow/overflow detection
  27. To activate the underflow/overflow detection, initialize the object with the model :
  28. ```python
  29. debug_overflow = DebugUnderflowOverflow(model)
  30. ```
  31. then run the training as normal and if `nan` or `inf` gets detected in at least one of the weight, input or output
  32. elements this module will throw an exception and will print `max_frames_to_save` frames that lead to this event,
  33. each frame reporting
  34. 1. the fully qualified module name plus the class name whose `forward` was run
  35. 2. the absolute min and max value of all elements for each module weights, and the inputs and output
  36. For example, here is the header and the last few frames in detection report for `google/mt5-small` run in fp16
  37. mixed precision :
  38. ```
  39. Detected inf/nan during batch_number=0
  40. Last 21 forward frames:
  41. abs min abs max metadata
  42. [...]
  43. encoder.block.2.layer.1.DenseReluDense.wi_0 Linear
  44. 2.17e-07 4.50e+00 weight
  45. 1.79e-06 4.65e+00 input[0]
  46. 2.68e-06 3.70e+01 output
  47. encoder.block.2.layer.1.DenseReluDense.wi_1 Linear
  48. 8.08e-07 2.66e+01 weight
  49. 1.79e-06 4.65e+00 input[0]
  50. 1.27e-04 2.37e+02 output
  51. encoder.block.2.layer.1.DenseReluDense.wo Linear
  52. 1.01e-06 6.44e+00 weight
  53. 0.00e+00 9.74e+03 input[0]
  54. 3.18e-04 6.27e+04 output
  55. encoder.block.2.layer.1.DenseReluDense T5DenseGatedGeluDense
  56. 1.79e-06 4.65e+00 input[0]
  57. 3.18e-04 6.27e+04 output
  58. encoder.block.2.layer.1.dropout Dropout
  59. 3.18e-04 6.27e+04 input[0]
  60. 0.00e+00 inf output
  61. ```
  62. You can see here, that `T5DenseGatedGeluDense.forward` resulted in output activations, whose absolute max value was
  63. around 62.7K, which is very close to fp16's top limit of 64K. In the next frame we have `Dropout` which
  64. renormalizes the weights, after it zeroed some of the elements, which pushes the absolute max value to more than
  65. 64K, and we get an overlow.
  66. As you can see it's the previous frames that we need to look into when the numbers start going into very large for
  67. fp16 numbers.
  68. The tracking is done in a forward hook, which gets invoked immediately after `forward` has completed.
  69. By default the last 21 frames are printed. You can change the default to adjust for your needs. For example :
  70. ```python
  71. debug_overflow = DebugUnderflowOverflow(model, max_frames_to_save=100)
  72. ```
  73. To validate that you have set up this debugging feature correctly, and you intend to use it in a training that
  74. may take hours to complete, first run it with normal tracing enabled for one of a few batches as explained in
  75. the next section.
  76. Mode 2. Specific batch absolute min/max tracing without detection
  77. The second work mode is per-batch tracing with the underflow/overflow detection feature turned off.
  78. Let's say you want to watch the absolute min and max values for all the ingredients of each `forward` call of a
  79. given batch, and only do that for batches 1 and 3. Then you instantiate this class as :
  80. ```python
  81. debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3])
  82. ```
  83. And now full batches 1 and 3 will be traced using the same format as explained above. Batches are 0-indexed.
  84. This is helpful if you know that the program starts misbehaving after a certain batch number, so you can
  85. fast-forward right to that area.
  86. Early stopping:
  87. You can also specify the batch number after which to stop the training, with :
  88. ```python
  89. debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3], abort_after_batch_num=3)
  90. ```
  91. This feature is mainly useful in the tracing mode, but you can use it for any mode.
  92. **Performance**:
  93. As this module measures absolute `min`/``max` of each weight of the model on every forward it'll slow the training
  94. down. Therefore remember to turn it off once the debugging needs have been met.
  95. Args:
  96. model (`nn.Module`):
  97. The model to debug.
  98. max_frames_to_save (`int`, *optional*, defaults to 21):
  99. How many frames back to record
  100. trace_batch_nums(`List[int]`, *optional*, defaults to `[]`):
  101. Which batch numbers to trace (turns detection off)
  102. abort_after_batch_num (`int``, *optional*):
  103. Whether to abort after a certain batch number has finished
  104. """
  105. def __init__(self, model, max_frames_to_save=21, trace_batch_nums=[], abort_after_batch_num=None):
  106. self.model = model
  107. self.trace_batch_nums = trace_batch_nums
  108. self.abort_after_batch_num = abort_after_batch_num
  109. # keep a LIFO buffer of frames to dump as soon as inf/nan is encountered to give context to the problem emergence
  110. self.frames = collections.deque([], max_frames_to_save)
  111. self.frame = []
  112. self.batch_number = 0
  113. self.total_calls = 0
  114. self.detected_overflow = False
  115. self.prefix = " "
  116. self.analyse_model()
  117. self.register_forward_hook()
  118. def save_frame(self, frame=None):
  119. if frame is not None:
  120. self.expand_frame(frame)
  121. self.frames.append("\n".join(self.frame))
  122. self.frame = [] # start a new frame
  123. def expand_frame(self, line):
  124. self.frame.append(line)
  125. def trace_frames(self):
  126. print("\n".join(self.frames))
  127. self.frames = []
  128. def reset_saved_frames(self):
  129. self.frames = []
  130. def dump_saved_frames(self):
  131. print(f"\nDetected inf/nan during batch_number={self.batch_number}")
  132. print(f"Last {len(self.frames)} forward frames:")
  133. print(f"{'abs min':8} {'abs max':8} metadata")
  134. print("\n".join(self.frames))
  135. print("\n\n")
  136. self.frames = []
  137. def analyse_model(self):
  138. # extract the fully qualified module names, to be able to report at run time. e.g.:
  139. # encoder.block.2.layer.0.SelfAttention.o
  140. #
  141. # for shared weights only the first shared module name will be registered
  142. self.module_names = {m: name for name, m in self.model.named_modules()}
  143. # self.longest_module_name = max(len(v) for v in self.module_names.values())
  144. def analyse_variable(self, var, ctx):
  145. if torch.is_tensor(var):
  146. self.expand_frame(get_abs_min_max(var, ctx))
  147. if detect_overflow(var, ctx):
  148. self.detected_overflow = True
  149. elif var is None:
  150. self.expand_frame(f"{'None':>17} {ctx}")
  151. else:
  152. self.expand_frame(f"{'not a tensor':>17} {ctx}")
  153. def batch_start_frame(self):
  154. self.expand_frame(f"\n\n{self.prefix} *** Starting batch number={self.batch_number} ***")
  155. self.expand_frame(f"{'abs min':8} {'abs max':8} metadata")
  156. def batch_end_frame(self):
  157. self.expand_frame(f"{self.prefix} *** Finished batch number={self.batch_number-1} ***\n\n")
  158. def create_frame(self, module, input, output):
  159. self.expand_frame(f"{self.prefix} {self.module_names[module]} {module.__class__.__name__}")
  160. # params
  161. for name, p in module.named_parameters(recurse=False):
  162. self.analyse_variable(p, name)
  163. # inputs
  164. if isinstance(input, tuple):
  165. for i, x in enumerate(input):
  166. self.analyse_variable(x, f"input[{i}]")
  167. else:
  168. self.analyse_variable(input, "input")
  169. # outputs
  170. if isinstance(output, tuple):
  171. for i, x in enumerate(output):
  172. # possibly a tuple of tuples
  173. if isinstance(x, tuple):
  174. for j, y in enumerate(x):
  175. self.analyse_variable(y, f"output[{i}][{j}]")
  176. else:
  177. self.analyse_variable(x, f"output[{i}]")
  178. else:
  179. self.analyse_variable(output, "output")
  180. self.save_frame()
  181. def register_forward_hook(self):
  182. self.model.apply(self._register_forward_hook)
  183. def _register_forward_hook(self, module):
  184. module.register_forward_hook(self.forward_hook)
  185. def forward_hook(self, module, input, output):
  186. # - input is a tuple of packed inputs (could be non-Tensors)
  187. # - output could be a Tensor or a tuple of Tensors and non-Tensors
  188. last_frame_of_batch = False
  189. trace_mode = True if self.batch_number in self.trace_batch_nums else False
  190. if trace_mode:
  191. self.reset_saved_frames()
  192. if self.total_calls == 0:
  193. self.batch_start_frame()
  194. self.total_calls += 1
  195. # count batch numbers - the very first forward hook of the batch will be called when the
  196. # batch completes - i.e. it gets called very last - we know this batch has finished
  197. if module == self.model:
  198. self.batch_number += 1
  199. last_frame_of_batch = True
  200. self.create_frame(module, input, output)
  201. # if last_frame_of_batch:
  202. # self.batch_end_frame()
  203. if trace_mode:
  204. self.trace_frames()
  205. if last_frame_of_batch:
  206. self.batch_start_frame()
  207. if self.detected_overflow and not trace_mode:
  208. self.dump_saved_frames()
  209. # now we can abort, as it's pointless to continue running
  210. raise ValueError(
  211. "DebugUnderflowOverflow: inf/nan detected, aborting as there is no point running further. "
  212. "Please scroll up above this traceback to see the activation values prior to this event."
  213. )
  214. # abort after certain batch if requested to do so
  215. if self.abort_after_batch_num is not None and self.batch_number > self.abort_after_batch_num:
  216. raise ValueError(
  217. f"DebugUnderflowOverflow: aborting after {self.batch_number} batches due to"
  218. f" `abort_after_batch_num={self.abort_after_batch_num}` arg"
  219. )
  220. def get_abs_min_max(var, ctx):
  221. abs_var = var.abs()
  222. return f"{abs_var.min():8.2e} {abs_var.max():8.2e} {ctx}"
  223. def detect_overflow(var, ctx):
  224. """
  225. Report whether the tensor contains any `nan` or `inf` entries.
  226. This is useful for detecting overflows/underflows and best to call right after the function that did some math that
  227. modified the tensor in question.
  228. This function contains a few other helper features that you can enable and tweak directly if you want to track
  229. various other things.
  230. Args:
  231. var: the tensor variable to check
  232. ctx: the message to print as a context
  233. Return:
  234. `True` if `inf` or `nan` was detected, `False` otherwise
  235. """
  236. detected = False
  237. if torch.isnan(var).any().item():
  238. detected = True
  239. print(f"{ctx} has nans")
  240. if torch.isinf(var).any().item():
  241. detected = True
  242. print(f"{ctx} has infs")
  243. # if needed to monitor large elements can enable the following
  244. if 0: # and detected:
  245. n100 = var[torch.ge(var.abs(), 100)]
  246. if n100.numel() > 0:
  247. print(f"{ctx}: n100={n100.numel()}")
  248. n1000 = var[torch.ge(var.abs(), 1000)]
  249. if n1000.numel() > 0:
  250. print(f"{ctx}: n1000={n1000.numel()}")
  251. n10000 = var[torch.ge(var.abs(), 10000)]
  252. if n10000.numel() > 0:
  253. print(f"{ctx}: n10000={n10000.numel()}")
  254. if 0:
  255. print(f"min={var.min():9.2e} max={var.max():9.2e}")
  256. if 0:
  257. print(f"min={var.min():9.2e} max={var.max():9.2e} var={var.var():9.2e} mean={var.mean():9.2e} ({ctx})")
  258. return detected
  259. class DebugOption(ExplicitEnum):
  260. UNDERFLOW_OVERFLOW = "underflow_overflow"
  261. TPU_METRICS_DEBUG = "tpu_metrics_debug"