writer.py 46 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208
  1. # mypy: allow-untyped-defs
  2. """Provide an API for writing protocol buffers to event files to be consumed by TensorBoard for visualization."""
  3. import os
  4. import time
  5. from typing import List, Optional, TYPE_CHECKING, Union
  6. import torch
  7. if TYPE_CHECKING:
  8. from matplotlib.figure import Figure
  9. from tensorboard.compat import tf
  10. from tensorboard.compat.proto import event_pb2
  11. from tensorboard.compat.proto.event_pb2 import Event, SessionLog
  12. from tensorboard.plugins.projector.projector_config_pb2 import ProjectorConfig
  13. from tensorboard.summary.writer.event_file_writer import EventFileWriter
  14. from ._convert_np import make_np
  15. from ._embedding import get_embedding_info, make_mat, make_sprite, make_tsv, write_pbtxt
  16. from ._onnx_graph import load_onnx_graph
  17. from ._pytorch_graph import graph
  18. from ._utils import figure_to_image
  19. from .summary import (
  20. audio,
  21. custom_scalars,
  22. histogram,
  23. histogram_raw,
  24. hparams,
  25. image,
  26. image_boxes,
  27. mesh,
  28. pr_curve,
  29. pr_curve_raw,
  30. scalar,
  31. tensor_proto,
  32. text,
  33. video,
  34. )
  35. __all__ = ["FileWriter", "SummaryWriter"]
  36. class FileWriter:
  37. """Writes protocol buffers to event files to be consumed by TensorBoard.
  38. The `FileWriter` class provides a mechanism to create an event file in a
  39. given directory and add summaries and events to it. The class updates the
  40. file contents asynchronously. This allows a training program to call methods
  41. to add data to the file directly from the training loop, without slowing down
  42. training.
  43. """
  44. def __init__(self, log_dir, max_queue=10, flush_secs=120, filename_suffix=""):
  45. """Create a `FileWriter` and an event file.
  46. On construction the writer creates a new event file in `log_dir`.
  47. The other arguments to the constructor control the asynchronous writes to
  48. the event file.
  49. Args:
  50. log_dir: A string. Directory where event file will be written.
  51. max_queue: Integer. Size of the queue for pending events and
  52. summaries before one of the 'add' calls forces a flush to disk.
  53. Default is ten items.
  54. flush_secs: Number. How often, in seconds, to flush the
  55. pending events and summaries to disk. Default is every two minutes.
  56. filename_suffix: A string. Suffix added to all event filenames
  57. in the log_dir directory. More details on filename construction in
  58. tensorboard.summary.writer.event_file_writer.EventFileWriter.
  59. """
  60. # Sometimes PosixPath is passed in and we need to coerce it to
  61. # a string in all cases
  62. # TODO: See if we can remove this in the future if we are
  63. # actually the ones passing in a PosixPath
  64. log_dir = str(log_dir)
  65. self.event_writer = EventFileWriter(
  66. log_dir, max_queue, flush_secs, filename_suffix
  67. )
  68. def get_logdir(self):
  69. """Return the directory where event file will be written."""
  70. return self.event_writer.get_logdir()
  71. def add_event(self, event, step=None, walltime=None):
  72. """Add an event to the event file.
  73. Args:
  74. event: An `Event` protocol buffer.
  75. step: Number. Optional global step value for training process
  76. to record with the event.
  77. walltime: float. Optional walltime to override the default (current)
  78. walltime (from time.time()) seconds after epoch
  79. """
  80. event.wall_time = time.time() if walltime is None else walltime
  81. if step is not None:
  82. # Make sure step is converted from numpy or other formats
  83. # since protobuf might not convert depending on version
  84. event.step = int(step)
  85. self.event_writer.add_event(event)
  86. def add_summary(self, summary, global_step=None, walltime=None):
  87. """Add a `Summary` protocol buffer to the event file.
  88. This method wraps the provided summary in an `Event` protocol buffer
  89. and adds it to the event file.
  90. Args:
  91. summary: A `Summary` protocol buffer.
  92. global_step: Number. Optional global step value for training process
  93. to record with the summary.
  94. walltime: float. Optional walltime to override the default (current)
  95. walltime (from time.time()) seconds after epoch
  96. """
  97. event = event_pb2.Event(summary=summary)
  98. self.add_event(event, global_step, walltime)
  99. def add_graph(self, graph_profile, walltime=None):
  100. """Add a `Graph` and step stats protocol buffer to the event file.
  101. Args:
  102. graph_profile: A `Graph` and step stats protocol buffer.
  103. walltime: float. Optional walltime to override the default (current)
  104. walltime (from time.time()) seconds after epoch
  105. """
  106. graph = graph_profile[0]
  107. stepstats = graph_profile[1]
  108. event = event_pb2.Event(graph_def=graph.SerializeToString())
  109. self.add_event(event, None, walltime)
  110. trm = event_pb2.TaggedRunMetadata(
  111. tag="step1", run_metadata=stepstats.SerializeToString()
  112. )
  113. event = event_pb2.Event(tagged_run_metadata=trm)
  114. self.add_event(event, None, walltime)
  115. def add_onnx_graph(self, graph, walltime=None):
  116. """Add a `Graph` protocol buffer to the event file.
  117. Args:
  118. graph: A `Graph` protocol buffer.
  119. walltime: float. Optional walltime to override the default (current)
  120. _get_file_writerfrom time.time())
  121. """
  122. event = event_pb2.Event(graph_def=graph.SerializeToString())
  123. self.add_event(event, None, walltime)
  124. def flush(self):
  125. """Flushes the event file to disk.
  126. Call this method to make sure that all pending events have been written to
  127. disk.
  128. """
  129. self.event_writer.flush()
  130. def close(self):
  131. """Flushes the event file to disk and close the file.
  132. Call this method when you do not need the summary writer anymore.
  133. """
  134. self.event_writer.close()
  135. def reopen(self):
  136. """Reopens the EventFileWriter.
  137. Can be called after `close()` to add more events in the same directory.
  138. The events will go into a new events file.
  139. Does nothing if the EventFileWriter was not closed.
  140. """
  141. self.event_writer.reopen()
  142. class SummaryWriter:
  143. """Writes entries directly to event files in the log_dir to be consumed by TensorBoard.
  144. The `SummaryWriter` class provides a high-level API to create an event file
  145. in a given directory and add summaries and events to it. The class updates the
  146. file contents asynchronously. This allows a training program to call methods
  147. to add data to the file directly from the training loop, without slowing down
  148. training.
  149. """
  150. def __init__(
  151. self,
  152. log_dir=None,
  153. comment="",
  154. purge_step=None,
  155. max_queue=10,
  156. flush_secs=120,
  157. filename_suffix="",
  158. ):
  159. """Create a `SummaryWriter` that will write out events and summaries to the event file.
  160. Args:
  161. log_dir (str): Save directory location. Default is
  162. runs/**CURRENT_DATETIME_HOSTNAME**, which changes after each run.
  163. Use hierarchical folder structure to compare
  164. between runs easily. e.g. pass in 'runs/exp1', 'runs/exp2', etc.
  165. for each new experiment to compare across them.
  166. comment (str): Comment log_dir suffix appended to the default
  167. ``log_dir``. If ``log_dir`` is assigned, this argument has no effect.
  168. purge_step (int):
  169. When logging crashes at step :math:`T+X` and restarts at step :math:`T`,
  170. any events whose global_step larger or equal to :math:`T` will be
  171. purged and hidden from TensorBoard.
  172. Note that crashed and resumed experiments should have the same ``log_dir``.
  173. max_queue (int): Size of the queue for pending events and
  174. summaries before one of the 'add' calls forces a flush to disk.
  175. Default is ten items.
  176. flush_secs (int): How often, in seconds, to flush the
  177. pending events and summaries to disk. Default is every two minutes.
  178. filename_suffix (str): Suffix added to all event filenames in
  179. the log_dir directory. More details on filename construction in
  180. tensorboard.summary.writer.event_file_writer.EventFileWriter.
  181. Examples::
  182. from torch.utils.tensorboard import SummaryWriter
  183. # create a summary writer with automatically generated folder name.
  184. writer = SummaryWriter()
  185. # folder location: runs/May04_22-14-54_s-MacBook-Pro.local/
  186. # create a summary writer using the specified folder name.
  187. writer = SummaryWriter("my_experiment")
  188. # folder location: my_experiment
  189. # create a summary writer with comment appended.
  190. writer = SummaryWriter(comment="LR_0.1_BATCH_16")
  191. # folder location: runs/May04_22-14-54_s-MacBook-Pro.localLR_0.1_BATCH_16/
  192. """
  193. torch._C._log_api_usage_once("tensorboard.create.summarywriter")
  194. if not log_dir:
  195. import socket
  196. from datetime import datetime
  197. current_time = datetime.now().strftime("%b%d_%H-%M-%S")
  198. log_dir = os.path.join(
  199. "runs", current_time + "_" + socket.gethostname() + comment
  200. )
  201. self.log_dir = log_dir
  202. self.purge_step = purge_step
  203. self.max_queue = max_queue
  204. self.flush_secs = flush_secs
  205. self.filename_suffix = filename_suffix
  206. # Initialize the file writers, but they can be cleared out on close
  207. # and recreated later as needed.
  208. self.file_writer = self.all_writers = None
  209. self._get_file_writer()
  210. # Create default bins for histograms, see generate_testdata.py in tensorflow/tensorboard
  211. v = 1e-12
  212. buckets = []
  213. neg_buckets = []
  214. while v < 1e20:
  215. buckets.append(v)
  216. neg_buckets.append(-v)
  217. v *= 1.1
  218. self.default_bins = neg_buckets[::-1] + [0] + buckets
  219. def _get_file_writer(self):
  220. """Return the default FileWriter instance. Recreates it if closed."""
  221. if self.all_writers is None or self.file_writer is None:
  222. self.file_writer = FileWriter(
  223. self.log_dir, self.max_queue, self.flush_secs, self.filename_suffix
  224. )
  225. self.all_writers = {self.file_writer.get_logdir(): self.file_writer}
  226. if self.purge_step is not None:
  227. most_recent_step = self.purge_step
  228. self.file_writer.add_event(
  229. Event(step=most_recent_step, file_version="brain.Event:2")
  230. )
  231. self.file_writer.add_event(
  232. Event(
  233. step=most_recent_step,
  234. session_log=SessionLog(status=SessionLog.START),
  235. )
  236. )
  237. self.purge_step = None
  238. return self.file_writer
  239. def get_logdir(self):
  240. """Return the directory where event files will be written."""
  241. return self.log_dir
  242. def add_hparams(
  243. self,
  244. hparam_dict,
  245. metric_dict,
  246. hparam_domain_discrete=None,
  247. run_name=None,
  248. global_step=None,
  249. ):
  250. """Add a set of hyperparameters to be compared in TensorBoard.
  251. Args:
  252. hparam_dict (dict): Each key-value pair in the dictionary is the
  253. name of the hyper parameter and it's corresponding value.
  254. The type of the value can be one of `bool`, `string`, `float`,
  255. `int`, or `None`.
  256. metric_dict (dict): Each key-value pair in the dictionary is the
  257. name of the metric and it's corresponding value. Note that the key used
  258. here should be unique in the tensorboard record. Otherwise the value
  259. you added by ``add_scalar`` will be displayed in hparam plugin. In most
  260. cases, this is unwanted.
  261. hparam_domain_discrete: (Optional[Dict[str, List[Any]]]) A dictionary that
  262. contains names of the hyperparameters and all discrete values they can hold
  263. run_name (str): Name of the run, to be included as part of the logdir.
  264. If unspecified, will use current timestamp.
  265. global_step (int): Global step value to record
  266. Examples::
  267. from torch.utils.tensorboard import SummaryWriter
  268. with SummaryWriter() as w:
  269. for i in range(5):
  270. w.add_hparams({'lr': 0.1*i, 'bsize': i},
  271. {'hparam/accuracy': 10*i, 'hparam/loss': 10*i})
  272. Expected result:
  273. .. image:: _static/img/tensorboard/add_hparam.png
  274. :scale: 50 %
  275. """
  276. torch._C._log_api_usage_once("tensorboard.logging.add_hparams")
  277. if type(hparam_dict) is not dict or type(metric_dict) is not dict:
  278. raise TypeError("hparam_dict and metric_dict should be dictionary.")
  279. exp, ssi, sei = hparams(hparam_dict, metric_dict, hparam_domain_discrete)
  280. if not run_name:
  281. run_name = str(time.time())
  282. logdir = os.path.join(self._get_file_writer().get_logdir(), run_name)
  283. with SummaryWriter(log_dir=logdir) as w_hp:
  284. w_hp.file_writer.add_summary(exp, global_step)
  285. w_hp.file_writer.add_summary(ssi, global_step)
  286. w_hp.file_writer.add_summary(sei, global_step)
  287. for k, v in metric_dict.items():
  288. w_hp.add_scalar(k, v, global_step)
  289. def add_scalar(
  290. self,
  291. tag,
  292. scalar_value,
  293. global_step=None,
  294. walltime=None,
  295. new_style=False,
  296. double_precision=False,
  297. ):
  298. """Add scalar data to summary.
  299. Args:
  300. tag (str): Data identifier
  301. scalar_value (float or string/blobname): Value to save
  302. global_step (int): Global step value to record
  303. walltime (float): Optional override default walltime (time.time())
  304. with seconds after epoch of event
  305. new_style (boolean): Whether to use new style (tensor field) or old
  306. style (simple_value field). New style could lead to faster data loading.
  307. Examples::
  308. from torch.utils.tensorboard import SummaryWriter
  309. writer = SummaryWriter()
  310. x = range(100)
  311. for i in x:
  312. writer.add_scalar('y=2x', i * 2, i)
  313. writer.close()
  314. Expected result:
  315. .. image:: _static/img/tensorboard/add_scalar.png
  316. :scale: 50 %
  317. """
  318. torch._C._log_api_usage_once("tensorboard.logging.add_scalar")
  319. summary = scalar(
  320. tag, scalar_value, new_style=new_style, double_precision=double_precision
  321. )
  322. self._get_file_writer().add_summary(summary, global_step, walltime)
  323. def add_scalars(self, main_tag, tag_scalar_dict, global_step=None, walltime=None):
  324. """Add many scalar data to summary.
  325. Args:
  326. main_tag (str): The parent name for the tags
  327. tag_scalar_dict (dict): Key-value pair storing the tag and corresponding values
  328. global_step (int): Global step value to record
  329. walltime (float): Optional override default walltime (time.time())
  330. seconds after epoch of event
  331. Examples::
  332. from torch.utils.tensorboard import SummaryWriter
  333. writer = SummaryWriter()
  334. r = 5
  335. for i in range(100):
  336. writer.add_scalars('run_14h', {'xsinx':i*np.sin(i/r),
  337. 'xcosx':i*np.cos(i/r),
  338. 'tanx': np.tan(i/r)}, i)
  339. writer.close()
  340. # This call adds three values to the same scalar plot with the tag
  341. # 'run_14h' in TensorBoard's scalar section.
  342. Expected result:
  343. .. image:: _static/img/tensorboard/add_scalars.png
  344. :scale: 50 %
  345. """
  346. torch._C._log_api_usage_once("tensorboard.logging.add_scalars")
  347. walltime = time.time() if walltime is None else walltime
  348. fw_logdir = self._get_file_writer().get_logdir()
  349. for tag, scalar_value in tag_scalar_dict.items():
  350. fw_tag = fw_logdir + "/" + main_tag.replace("/", "_") + "_" + tag
  351. assert self.all_writers is not None
  352. if fw_tag in self.all_writers.keys():
  353. fw = self.all_writers[fw_tag]
  354. else:
  355. fw = FileWriter(
  356. fw_tag, self.max_queue, self.flush_secs, self.filename_suffix
  357. )
  358. self.all_writers[fw_tag] = fw
  359. fw.add_summary(scalar(main_tag, scalar_value), global_step, walltime)
  360. def add_tensor(
  361. self,
  362. tag,
  363. tensor,
  364. global_step=None,
  365. walltime=None,
  366. ):
  367. """Add tensor data to summary.
  368. Args:
  369. tag (str): Data identifier
  370. tensor (torch.Tensor): tensor to save
  371. global_step (int): Global step value to record
  372. Examples::
  373. from torch.utils.tensorboard import SummaryWriter
  374. writer = SummaryWriter()
  375. x = torch.tensor([1,2,3])
  376. writer.add_scalar('x', x)
  377. writer.close()
  378. Expected result:
  379. Summary::tensor::float_val [1,2,3]
  380. ::tensor::shape [3]
  381. ::tag 'x'
  382. """
  383. torch._C._log_api_usage_once("tensorboard.logging.add_tensor")
  384. summary = tensor_proto(tag, tensor)
  385. self._get_file_writer().add_summary(summary, global_step, walltime)
  386. def add_histogram(
  387. self,
  388. tag,
  389. values,
  390. global_step=None,
  391. bins="tensorflow",
  392. walltime=None,
  393. max_bins=None,
  394. ):
  395. """Add histogram to summary.
  396. Args:
  397. tag (str): Data identifier
  398. values (torch.Tensor, numpy.ndarray, or string/blobname): Values to build histogram
  399. global_step (int): Global step value to record
  400. bins (str): One of {'tensorflow','auto', 'fd', ...}. This determines how the bins are made. You can find
  401. other options in: https://docs.scipy.org/doc/numpy/reference/generated/numpy.histogram.html
  402. walltime (float): Optional override default walltime (time.time())
  403. seconds after epoch of event
  404. Examples::
  405. from torch.utils.tensorboard import SummaryWriter
  406. import numpy as np
  407. writer = SummaryWriter()
  408. for i in range(10):
  409. x = np.random.random(1000)
  410. writer.add_histogram('distribution centers', x + i, i)
  411. writer.close()
  412. Expected result:
  413. .. image:: _static/img/tensorboard/add_histogram.png
  414. :scale: 50 %
  415. """
  416. torch._C._log_api_usage_once("tensorboard.logging.add_histogram")
  417. if isinstance(bins, str) and bins == "tensorflow":
  418. bins = self.default_bins
  419. self._get_file_writer().add_summary(
  420. histogram(tag, values, bins, max_bins=max_bins), global_step, walltime
  421. )
  422. def add_histogram_raw(
  423. self,
  424. tag,
  425. min,
  426. max,
  427. num,
  428. sum,
  429. sum_squares,
  430. bucket_limits,
  431. bucket_counts,
  432. global_step=None,
  433. walltime=None,
  434. ):
  435. """Add histogram with raw data.
  436. Args:
  437. tag (str): Data identifier
  438. min (float or int): Min value
  439. max (float or int): Max value
  440. num (int): Number of values
  441. sum (float or int): Sum of all values
  442. sum_squares (float or int): Sum of squares for all values
  443. bucket_limits (torch.Tensor, numpy.ndarray): Upper value per bucket.
  444. The number of elements of it should be the same as `bucket_counts`.
  445. bucket_counts (torch.Tensor, numpy.ndarray): Number of values per bucket
  446. global_step (int): Global step value to record
  447. walltime (float): Optional override default walltime (time.time())
  448. seconds after epoch of event
  449. see: https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/histogram/README.md
  450. Examples::
  451. from torch.utils.tensorboard import SummaryWriter
  452. import numpy as np
  453. writer = SummaryWriter()
  454. dummy_data = []
  455. for idx, value in enumerate(range(50)):
  456. dummy_data += [idx + 0.001] * value
  457. bins = list(range(50+2))
  458. bins = np.array(bins)
  459. values = np.array(dummy_data).astype(float).reshape(-1)
  460. counts, limits = np.histogram(values, bins=bins)
  461. sum_sq = values.dot(values)
  462. writer.add_histogram_raw(
  463. tag='histogram_with_raw_data',
  464. min=values.min(),
  465. max=values.max(),
  466. num=len(values),
  467. sum=values.sum(),
  468. sum_squares=sum_sq,
  469. bucket_limits=limits[1:].tolist(),
  470. bucket_counts=counts.tolist(),
  471. global_step=0)
  472. writer.close()
  473. Expected result:
  474. .. image:: _static/img/tensorboard/add_histogram_raw.png
  475. :scale: 50 %
  476. """
  477. torch._C._log_api_usage_once("tensorboard.logging.add_histogram_raw")
  478. if len(bucket_limits) != len(bucket_counts):
  479. raise ValueError(
  480. "len(bucket_limits) != len(bucket_counts), see the document."
  481. )
  482. self._get_file_writer().add_summary(
  483. histogram_raw(
  484. tag, min, max, num, sum, sum_squares, bucket_limits, bucket_counts
  485. ),
  486. global_step,
  487. walltime,
  488. )
  489. def add_image(
  490. self, tag, img_tensor, global_step=None, walltime=None, dataformats="CHW"
  491. ):
  492. """Add image data to summary.
  493. Note that this requires the ``pillow`` package.
  494. Args:
  495. tag (str): Data identifier
  496. img_tensor (torch.Tensor, numpy.ndarray, or string/blobname): Image data
  497. global_step (int): Global step value to record
  498. walltime (float): Optional override default walltime (time.time())
  499. seconds after epoch of event
  500. dataformats (str): Image data format specification of the form
  501. CHW, HWC, HW, WH, etc.
  502. Shape:
  503. img_tensor: Default is :math:`(3, H, W)`. You can use ``torchvision.utils.make_grid()`` to
  504. convert a batch of tensor into 3xHxW format or call ``add_images`` and let us do the job.
  505. Tensor with :math:`(1, H, W)`, :math:`(H, W)`, :math:`(H, W, 3)` is also suitable as long as
  506. corresponding ``dataformats`` argument is passed, e.g. ``CHW``, ``HWC``, ``HW``.
  507. Examples::
  508. from torch.utils.tensorboard import SummaryWriter
  509. import numpy as np
  510. img = np.zeros((3, 100, 100))
  511. img[0] = np.arange(0, 10000).reshape(100, 100) / 10000
  512. img[1] = 1 - np.arange(0, 10000).reshape(100, 100) / 10000
  513. img_HWC = np.zeros((100, 100, 3))
  514. img_HWC[:, :, 0] = np.arange(0, 10000).reshape(100, 100) / 10000
  515. img_HWC[:, :, 1] = 1 - np.arange(0, 10000).reshape(100, 100) / 10000
  516. writer = SummaryWriter()
  517. writer.add_image('my_image', img, 0)
  518. # If you have non-default dimension setting, set the dataformats argument.
  519. writer.add_image('my_image_HWC', img_HWC, 0, dataformats='HWC')
  520. writer.close()
  521. Expected result:
  522. .. image:: _static/img/tensorboard/add_image.png
  523. :scale: 50 %
  524. """
  525. torch._C._log_api_usage_once("tensorboard.logging.add_image")
  526. self._get_file_writer().add_summary(
  527. image(tag, img_tensor, dataformats=dataformats), global_step, walltime
  528. )
  529. def add_images(
  530. self, tag, img_tensor, global_step=None, walltime=None, dataformats="NCHW"
  531. ):
  532. """Add batched image data to summary.
  533. Note that this requires the ``pillow`` package.
  534. Args:
  535. tag (str): Data identifier
  536. img_tensor (torch.Tensor, numpy.ndarray, or string/blobname): Image data
  537. global_step (int): Global step value to record
  538. walltime (float): Optional override default walltime (time.time())
  539. seconds after epoch of event
  540. dataformats (str): Image data format specification of the form
  541. NCHW, NHWC, CHW, HWC, HW, WH, etc.
  542. Shape:
  543. img_tensor: Default is :math:`(N, 3, H, W)`. If ``dataformats`` is specified, other shape will be
  544. accepted. e.g. NCHW or NHWC.
  545. Examples::
  546. from torch.utils.tensorboard import SummaryWriter
  547. import numpy as np
  548. img_batch = np.zeros((16, 3, 100, 100))
  549. for i in range(16):
  550. img_batch[i, 0] = np.arange(0, 10000).reshape(100, 100) / 10000 / 16 * i
  551. img_batch[i, 1] = (1 - np.arange(0, 10000).reshape(100, 100) / 10000) / 16 * i
  552. writer = SummaryWriter()
  553. writer.add_images('my_image_batch', img_batch, 0)
  554. writer.close()
  555. Expected result:
  556. .. image:: _static/img/tensorboard/add_images.png
  557. :scale: 30 %
  558. """
  559. torch._C._log_api_usage_once("tensorboard.logging.add_images")
  560. self._get_file_writer().add_summary(
  561. image(tag, img_tensor, dataformats=dataformats), global_step, walltime
  562. )
  563. def add_image_with_boxes(
  564. self,
  565. tag,
  566. img_tensor,
  567. box_tensor,
  568. global_step=None,
  569. walltime=None,
  570. rescale=1,
  571. dataformats="CHW",
  572. labels=None,
  573. ):
  574. """Add image and draw bounding boxes on the image.
  575. Args:
  576. tag (str): Data identifier
  577. img_tensor (torch.Tensor, numpy.ndarray, or string/blobname): Image data
  578. box_tensor (torch.Tensor, numpy.ndarray, or string/blobname): Box data (for detected objects)
  579. box should be represented as [x1, y1, x2, y2].
  580. global_step (int): Global step value to record
  581. walltime (float): Optional override default walltime (time.time())
  582. seconds after epoch of event
  583. rescale (float): Optional scale override
  584. dataformats (str): Image data format specification of the form
  585. NCHW, NHWC, CHW, HWC, HW, WH, etc.
  586. labels (list of string): The label to be shown for each bounding box.
  587. Shape:
  588. img_tensor: Default is :math:`(3, H, W)`. It can be specified with ``dataformats`` argument.
  589. e.g. CHW or HWC
  590. box_tensor: (torch.Tensor, numpy.ndarray, or string/blobname): NX4, where N is the number of
  591. boxes and each 4 elements in a row represents (xmin, ymin, xmax, ymax).
  592. """
  593. torch._C._log_api_usage_once("tensorboard.logging.add_image_with_boxes")
  594. if labels is not None:
  595. if isinstance(labels, str):
  596. labels = [labels]
  597. if len(labels) != box_tensor.shape[0]:
  598. labels = None
  599. self._get_file_writer().add_summary(
  600. image_boxes(
  601. tag,
  602. img_tensor,
  603. box_tensor,
  604. rescale=rescale,
  605. dataformats=dataformats,
  606. labels=labels,
  607. ),
  608. global_step,
  609. walltime,
  610. )
  611. def add_figure(
  612. self,
  613. tag: str,
  614. figure: Union["Figure", List["Figure"]],
  615. global_step: Optional[int] = None,
  616. close: bool = True,
  617. walltime: Optional[float] = None,
  618. ) -> None:
  619. """Render matplotlib figure into an image and add it to summary.
  620. Note that this requires the ``matplotlib`` package.
  621. Args:
  622. tag: Data identifier
  623. figure: Figure or a list of figures
  624. global_step: Global step value to record
  625. close: Flag to automatically close the figure
  626. walltime: Optional override default walltime (time.time())
  627. seconds after epoch of event
  628. """
  629. torch._C._log_api_usage_once("tensorboard.logging.add_figure")
  630. if isinstance(figure, list):
  631. self.add_image(
  632. tag,
  633. figure_to_image(figure, close),
  634. global_step,
  635. walltime,
  636. dataformats="NCHW",
  637. )
  638. else:
  639. self.add_image(
  640. tag,
  641. figure_to_image(figure, close),
  642. global_step,
  643. walltime,
  644. dataformats="CHW",
  645. )
  646. def add_video(self, tag, vid_tensor, global_step=None, fps=4, walltime=None):
  647. """Add video data to summary.
  648. Note that this requires the ``moviepy`` package.
  649. Args:
  650. tag (str): Data identifier
  651. vid_tensor (torch.Tensor): Video data
  652. global_step (int): Global step value to record
  653. fps (float or int): Frames per second
  654. walltime (float): Optional override default walltime (time.time())
  655. seconds after epoch of event
  656. Shape:
  657. vid_tensor: :math:`(N, T, C, H, W)`. The values should lie in [0, 255] for type `uint8` or [0, 1] for type `float`.
  658. """
  659. torch._C._log_api_usage_once("tensorboard.logging.add_video")
  660. self._get_file_writer().add_summary(
  661. video(tag, vid_tensor, fps), global_step, walltime
  662. )
  def add_audio(
      self, tag, snd_tensor, global_step=None, sample_rate=44100, walltime=None
  ):
      """Record an audio clip in the summary.

      Args:
          tag (str): Data identifier
          snd_tensor (torch.Tensor): Sound samples
          global_step (int): Global step value to record
          sample_rate (int): Sampling rate in Hz
          walltime (float): Optional override default walltime (time.time())
            seconds after epoch of event

      Shape:
          snd_tensor: :math:`(1, L)`. The values should lie between [-1, 1].
      """
      torch._C._log_api_usage_once("tensorboard.logging.add_audio")
      event_writer = self._get_file_writer()
      sound_summary = audio(tag, snd_tensor, sample_rate=sample_rate)
      event_writer.add_summary(sound_summary, global_step, walltime)
  def add_text(self, tag, text_string, global_step=None, walltime=None):
      """Record a text entry in the summary.

      Args:
          tag (str): Data identifier
          text_string (str): String to save
          global_step (int): Global step value to record
          walltime (float): Optional override default walltime (time.time())
            seconds after epoch of event

      Examples::

          writer.add_text('lstm', 'This is an lstm', 0)
          writer.add_text('rnn', 'This is an rnn', 10)
      """
      torch._C._log_api_usage_once("tensorboard.logging.add_text")
      event_writer = self._get_file_writer()
      text_summary = text(tag, text_string)
      event_writer.add_summary(text_summary, global_step, walltime)
  def add_onnx_graph(self, prototxt):
      """Add an ONNX model graph to the summary.

      Args:
          prototxt: Value handed unchanged to ``load_onnx_graph``; presumably a
            path to a serialized ONNX model (TODO confirm against that helper).
      """
      torch._C._log_api_usage_once("tensorboard.logging.add_onnx_graph")
      event_writer = self._get_file_writer()
      onnx_graph = load_onnx_graph(prototxt)
      event_writer.add_onnx_graph(onnx_graph)
  def add_graph(
      self, model, input_to_model=None, verbose=False, use_strict_trace=True
  ):
      """Add a model's computation graph to the summary.

      Args:
          model (torch.nn.Module): Model to draw.
          input_to_model (torch.Tensor or list of torch.Tensor): A variable or a tuple of
              variables to be fed.
          verbose (bool): Whether to print graph structure in console.
          use_strict_trace (bool): Whether to pass keyword argument `strict` to
              `torch.jit.trace`. Pass False when you want the tracer to
              record your mutable container types (list, dict)
      """
      torch._C._log_api_usage_once("tensorboard.logging.add_graph")
      event_writer = self._get_file_writer()
      # NOTE: no check is made here that ``model`` defines ``forward``; a valid
      # PyTorch model is expected to have one.
      model_graph = graph(model, input_to_model, verbose, use_strict_trace)
      event_writer.add_graph(model_graph)
  718. @staticmethod
  719. def _encode(rawstr):
  720. # I'd use urllib but, I'm unsure about the differences from python3 to python2, etc.
  721. retval = rawstr
  722. retval = retval.replace("%", f"%{ord('%'):02x}")
  723. retval = retval.replace("/", f"%{ord('/'):02x}")
  724. retval = retval.replace("\\", "%%%02x" % (ord("\\"))) # noqa: UP031
  725. return retval
  def add_embedding(
      self,
      mat,
      metadata=None,
      label_img=None,
      global_step=None,
      tag="default",
      metadata_header=None,
  ):
      """Add embedding projector data to summary.

      The matrix (plus optional metadata and sprite images) is written under
      ``<logdir>/<global_step zero-padded to 5 digits>/<encoded tag>``. A
      projector config entry is then appended, and the whole config is
      re-written to the log directory.

      Args:
          mat (torch.Tensor or numpy.ndarray): A matrix in which each row is the feature vector of a data point
          metadata (list): A list of labels, each element will be converted to string
          label_img (torch.Tensor): Images corresponding to each data point
          global_step (int): Global step value to record (``None`` is treated as 0)
          tag (str): Name for the embedding
          metadata_header (list): A list of headers for multi-column metadata. If given, each metadata must be
              a list with values corresponding to headers.

      Raises:
          AssertionError: if ``mat`` is not 2D, or if ``metadata`` / ``label_img``
              do not have one entry per row of ``mat``. These checks run before
              anything is written to disk.
          NotADirectoryError: if the target embedding path exists but is a file.

      Shape:
          mat: :math:`(N, D)`, where N is number of data and D is feature dimension
          label_img: :math:`(N, C, H, W)`

      Examples::

          import keyword
          import torch
          meta = []
          while len(meta)<100:
              meta = meta+keyword.kwlist # get some strings
          meta = meta[:100]

          for i, v in enumerate(meta):
              meta[i] = v+str(i)

          label_img = torch.rand(100, 3, 10, 32)
          for i in range(100):
              label_img[i]*=i/100.0

          writer.add_embedding(torch.randn(100, 5), metadata=meta, label_img=label_img)
          writer.add_embedding(torch.randn(100, 5), label_img=label_img)
          writer.add_embedding(torch.randn(100, 5), metadata=meta)

      .. note::
          Categorical (i.e. non-numeric) metadata cannot have more than 50 unique values if they are to be used for
          coloring in the embedding projector.
      """
      torch._C._log_api_usage_once("tensorboard.logging.add_embedding")
      mat = make_np(mat)

      # Validate every input before touching the filesystem, so a bad call
      # cannot leave a partially populated embedding directory behind. The
      # rank check comes first because the size checks index mat.shape[0].
      assert (
          mat.ndim == 2
      ), "mat should be 2D, where mat.size(0) is the number of data points"
      if metadata is not None:
          assert mat.shape[0] == len(
              metadata
          ), "#labels should equal with #data points"
      if label_img is not None:
          assert (
              mat.shape[0] == label_img.shape[0]
          ), "#images should equal with #data points"

      if global_step is None:
          global_step = 0
      # The tag is percent-encoded so slashes in it cannot introduce extra
      # directory levels under the step directory.
      subdir = f"{str(global_step).zfill(5)}/{self._encode(tag)}"
      save_path = os.path.join(self._get_file_writer().get_logdir(), subdir)

      fs = tf.io.gfile
      if fs.exists(save_path):
          if fs.isdir(save_path):
              print(
                  "warning: Embedding dir exists, did you set global_step for add_embedding()?"
              )
          else:
              raise NotADirectoryError(
                  f"Path: `{save_path}` exists, but is a file. Cannot proceed."
              )
      else:
          fs.makedirs(save_path)

      if metadata is not None:
          make_tsv(metadata, save_path, metadata_header=metadata_header)
      if label_img is not None:
          make_sprite(label_img, save_path)
      make_mat(mat, save_path)

      # Filesystem doesn't necessarily have append semantics, so we store an
      # internal buffer to append to and re-write whole file after each
      # embedding is added
      if not hasattr(self, "_projector_config"):
          self._projector_config = ProjectorConfig()
      embedding_info = get_embedding_info(
          metadata, label_img, subdir, global_step, tag
      )
      self._projector_config.embeddings.extend([embedding_info])

      from google.protobuf import text_format

      config_pbtxt = text_format.MessageToString(self._projector_config)
      write_pbtxt(self._get_file_writer().get_logdir(), config_pbtxt)
  def add_pr_curve(
      self,
      tag,
      labels,
      predictions,
      global_step=None,
      num_thresholds=127,
      weights=None,
      walltime=None,
  ):
      """Add a precision-recall curve computed from labels and predictions.

      A precision-recall curve shows how the model performs under different
      threshold settings. Supply the ground-truth labeling (T/F) and the
      prediction confidence (usually the model output) for each target; the
      TensorBoard UI lets you choose the threshold interactively.

      Args:
          tag (str): Data identifier
          labels (torch.Tensor, numpy.ndarray, or string/blobname):
              Ground truth data. Binary label for each element.
          predictions (torch.Tensor, numpy.ndarray, or string/blobname):
              The probability that an element be classified as true.
              Value should be in [0, 1]
          global_step (int): Global step value to record
          num_thresholds (int): Number of thresholds used to draw the curve.
          weights: Optional weights, forwarded unchanged to ``pr_curve``.
          walltime (float): Optional override default walltime (time.time())
              seconds after epoch of event

      Examples::

          from torch.utils.tensorboard import SummaryWriter
          import numpy as np
          labels = np.random.randint(2, size=100)  # binary label
          predictions = np.random.rand(100)
          writer = SummaryWriter()
          writer.add_pr_curve('pr_curve', labels, predictions, 0)
          writer.close()
      """
      torch._C._log_api_usage_once("tensorboard.logging.add_pr_curve")
      label_array = make_np(labels)
      prediction_array = make_np(predictions)
      event_writer = self._get_file_writer()
      curve_summary = pr_curve(
          tag, label_array, prediction_array, num_thresholds, weights
      )
      event_writer.add_summary(curve_summary, global_step, walltime)
  def add_pr_curve_raw(
      self,
      tag,
      true_positive_counts,
      false_positive_counts,
      true_negative_counts,
      false_negative_counts,
      precision,
      recall,
      global_step=None,
      num_thresholds=127,
      weights=None,
      walltime=None,
  ):
      """Add a precision-recall curve from precomputed counts and rates.

      Args:
          tag (str): Data identifier
          true_positive_counts (torch.Tensor, numpy.ndarray, or string/blobname): true positive counts
          false_positive_counts (torch.Tensor, numpy.ndarray, or string/blobname): false positive counts
          true_negative_counts (torch.Tensor, numpy.ndarray, or string/blobname): true negative counts
          false_negative_counts (torch.Tensor, numpy.ndarray, or string/blobname): false negative counts
          precision (torch.Tensor, numpy.ndarray, or string/blobname): precision
          recall (torch.Tensor, numpy.ndarray, or string/blobname): recall
          global_step (int): Global step value to record
          num_thresholds (int): Number of thresholds used to draw the curve.
          weights: Optional weights, forwarded unchanged to ``pr_curve_raw``.
          walltime (float): Optional override default walltime (time.time())
              seconds after epoch of event

      see: https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/pr_curve/README.md
      """
      torch._C._log_api_usage_once("tensorboard.logging.add_pr_curve_raw")
      event_writer = self._get_file_writer()
      raw_curve = pr_curve_raw(
          tag,
          true_positive_counts,
          false_positive_counts,
          true_negative_counts,
          false_negative_counts,
          precision,
          recall,
          num_thresholds,
          weights,
      )
      event_writer.add_summary(raw_curve, global_step, walltime)
  def add_custom_scalars_multilinechart(
      self, tags, category="default", title="untitled"
  ):
      """Shorthand for creating a multiline chart.

      Similar to ``add_custom_scalars()``, but the only necessary argument is
      *tags*.

      Args:
          tags (list): list of tags that have been used in ``add_scalar()``
          category (str): Category the chart is grouped under.
          title (str): Title of the chart.

      Examples::

          writer.add_custom_scalars_multilinechart(['twse/0050', 'twse/2330'])
      """
      torch._C._log_api_usage_once(
          "tensorboard.logging.add_custom_scalars_multilinechart"
      )
      chart_spec = ["Multiline", tags]
      layout = {category: {title: chart_spec}}
      event_writer = self._get_file_writer()
      event_writer.add_summary(custom_scalars(layout))
  def add_custom_scalars_marginchart(
      self, tags, category="default", title="untitled"
  ):
      """Shorthand for creating a margin chart.

      Similar to ``add_custom_scalars()``, but the only necessary argument is
      *tags*, which must contain exactly 3 elements.

      Args:
          tags (list): list of tags that have been used in ``add_scalar()``
          category (str): Category the chart is grouped under.
          title (str): Title of the chart.

      Examples::

          writer.add_custom_scalars_marginchart(['twse/0050', 'twse/2330', 'twse/2006'])
      """
      torch._C._log_api_usage_once(
          "tensorboard.logging.add_custom_scalars_marginchart"
      )
      assert len(tags) == 3
      chart_spec = ["Margin", tags]
      layout = {category: {title: chart_spec}}
      event_writer = self._get_file_writer()
      event_writer.add_summary(custom_scalars(layout))
  def add_custom_scalars(self, layout):
      """Create special charts by collecting chart tags in 'scalars'.

      NOTE: This function can only be called once for each SummaryWriter() object.

      Because it only provides metadata to tensorboard, the function can be
      called before or after the training loop.

      Args:
          layout (dict): {categoryName: *charts*}, where *charts* is also a dictionary
              {chartName: *ListOfProperties*}. The first element in *ListOfProperties* is the chart's type
              (one of **Multiline** or **Margin**) and the second element should be a list containing the tags
              you have used in add_scalar function, which will be collected into the new chart.

      Examples::

          layout = {'Taiwan':{'twse':['Multiline',['twse/0050', 'twse/2330']]},
                       'USA':{ 'dow':['Margin',   ['dow/aaa', 'dow/bbb', 'dow/ccc']],
                            'nasdaq':['Margin',   ['nasdaq/aaa', 'nasdaq/bbb', 'nasdaq/ccc']]}}

          writer.add_custom_scalars(layout)
      """
      torch._C._log_api_usage_once("tensorboard.logging.add_custom_scalars")
      event_writer = self._get_file_writer()
      layout_summary = custom_scalars(layout)
      event_writer.add_summary(layout_summary)
  def add_mesh(
      self,
      tag,
      vertices,
      colors=None,
      faces=None,
      config_dict=None,
      global_step=None,
      walltime=None,
  ):
      """Add meshes or 3D point clouds to TensorBoard.

      The visualization is based on Three.js, so users can interact with the
      rendered object. Beyond the basic definitions such as vertices and
      faces, users can also provide camera parameters, lighting conditions,
      etc. See https://threejs.org/docs/index.html#manual/en/introduction/Creating-a-scene
      for advanced usage.

      Args:
          tag (str): Data identifier
          vertices (torch.Tensor): List of the 3D coordinates of vertices.
          colors (torch.Tensor): Colors for each vertex
          faces (torch.Tensor): Indices of vertices within each triangle. (Optional)
          config_dict: Dictionary with ThreeJS classes names and configuration.
          global_step (int): Global step value to record
          walltime (float): Optional override default walltime (time.time())
              seconds after epoch of event

      Shape:
          vertices: :math:`(B, N, 3)`. (batch, number_of_vertices, channels)

          colors: :math:`(B, N, 3)`. The values should lie in [0, 255] for type `uint8` or [0, 1] for type `float`.

          faces: :math:`(B, N, 3)`. The values should lie in [0, number_of_vertices] for type `uint8`.

      Examples::

          from torch.utils.tensorboard import SummaryWriter
          vertices_tensor = torch.as_tensor([
              [1, 1, 1],
              [-1, -1, 1],
              [1, -1, -1],
              [-1, 1, -1],
          ], dtype=torch.float).unsqueeze(0)
          colors_tensor = torch.as_tensor([
              [255, 0, 0],
              [0, 255, 0],
              [0, 0, 255],
              [255, 0, 255],
          ], dtype=torch.int).unsqueeze(0)
          faces_tensor = torch.as_tensor([
              [0, 2, 3],
              [0, 3, 1],
              [0, 1, 2],
              [1, 3, 2],
          ], dtype=torch.int).unsqueeze(0)

          writer = SummaryWriter()
          writer.add_mesh('my_mesh', vertices=vertices_tensor, colors=colors_tensor, faces=faces_tensor)

          writer.close()
      """
      torch._C._log_api_usage_once("tensorboard.logging.add_mesh")
      event_writer = self._get_file_writer()
      mesh_summary = mesh(tag, vertices, colors, faces, config_dict)
      event_writer.add_summary(mesh_summary, global_step, walltime)
  def flush(self):
      """Flush every open event-file writer to disk.

      Call this to make sure all pending events have been written. It does
      nothing once the writers have been released (``all_writers is None``).
      """
      open_writers = self.all_writers
      if open_writers is not None:
          for event_writer in open_writers.values():
              event_writer.flush()
  def close(self):
      """Flush and close every event-file writer, then release them.

      After the first call ``file_writer`` and ``all_writers`` are both None,
      so calling ``close`` again is a no-op.
      """
      open_writers = self.all_writers
      if open_writers is None:
          return  # already closed
      for event_writer in open_writers.values():
          event_writer.flush()
          event_writer.close()
      self.file_writer = None
      self.all_writers = None
  1022. def __enter__(self):
  1023. return self
  1024. def __exit__(self, exc_type, exc_val, exc_tb):
  1025. self.close()