summary.py 34 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982
  1. # mypy: allow-untyped-defs
  2. import json
  3. import logging
  4. import os
  5. import struct
  6. from typing import Any, List, Optional
  7. import torch
  8. import numpy as np
  9. from google.protobuf import struct_pb2
  10. from tensorboard.compat.proto.summary_pb2 import (
  11. HistogramProto,
  12. Summary,
  13. SummaryMetadata,
  14. )
  15. from tensorboard.compat.proto.tensor_pb2 import TensorProto
  16. from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
  17. from tensorboard.plugins.custom_scalar import layout_pb2
  18. from tensorboard.plugins.pr_curve.plugin_data_pb2 import PrCurvePluginData
  19. from tensorboard.plugins.text.plugin_data_pb2 import TextPluginData
  20. from ._convert_np import make_np
  21. from ._utils import _prepare_video, convert_to_HWC
  22. __all__ = [
  23. "half_to_int",
  24. "int_to_half",
  25. "hparams",
  26. "scalar",
  27. "histogram_raw",
  28. "histogram",
  29. "make_histogram",
  30. "image",
  31. "image_boxes",
  32. "draw_boxes",
  33. "make_image",
  34. "video",
  35. "make_video",
  36. "audio",
  37. "custom_scalars",
  38. "text",
  39. "tensor_proto",
  40. "pr_curve_raw",
  41. "pr_curve",
  42. "compute_curve",
  43. "mesh",
  44. ]
  45. logger = logging.getLogger(__name__)
  46. def half_to_int(f: float) -> int:
  47. """Casts a half-precision float value into an integer.
  48. Converts a half precision floating point value, such as `torch.half` or
  49. `torch.bfloat16`, into an integer value which can be written into the
  50. half_val field of a TensorProto for storage.
  51. To undo the effects of this conversion, use int_to_half().
  52. """
  53. buf = struct.pack("f", f)
  54. return struct.unpack("i", buf)[0]
  55. def int_to_half(i: int) -> float:
  56. """Casts an integer value to a half-precision float.
  57. Converts an integer value obtained from half_to_int back into a floating
  58. point value.
  59. """
  60. buf = struct.pack("i", i)
  61. return struct.unpack("f", buf)[0]
  62. def _tensor_to_half_val(t: torch.Tensor) -> List[int]:
  63. return [half_to_int(x) for x in t.flatten().tolist()]
  64. def _tensor_to_complex_val(t: torch.Tensor) -> List[float]:
  65. return torch.view_as_real(t).flatten().tolist()
  66. def _tensor_to_list(t: torch.Tensor) -> List[Any]:
  67. return t.flatten().tolist()
  68. # type maps: torch.Tensor type -> (protobuf type, protobuf val field)
  69. _TENSOR_TYPE_MAP = {
  70. torch.half: ("DT_HALF", "half_val", _tensor_to_half_val),
  71. torch.float16: ("DT_HALF", "half_val", _tensor_to_half_val),
  72. torch.bfloat16: ("DT_BFLOAT16", "half_val", _tensor_to_half_val),
  73. torch.float32: ("DT_FLOAT", "float_val", _tensor_to_list),
  74. torch.float: ("DT_FLOAT", "float_val", _tensor_to_list),
  75. torch.float64: ("DT_DOUBLE", "double_val", _tensor_to_list),
  76. torch.double: ("DT_DOUBLE", "double_val", _tensor_to_list),
  77. torch.int8: ("DT_INT8", "int_val", _tensor_to_list),
  78. torch.uint8: ("DT_UINT8", "int_val", _tensor_to_list),
  79. torch.qint8: ("DT_UINT8", "int_val", _tensor_to_list),
  80. torch.int16: ("DT_INT16", "int_val", _tensor_to_list),
  81. torch.short: ("DT_INT16", "int_val", _tensor_to_list),
  82. torch.int: ("DT_INT32", "int_val", _tensor_to_list),
  83. torch.int32: ("DT_INT32", "int_val", _tensor_to_list),
  84. torch.qint32: ("DT_INT32", "int_val", _tensor_to_list),
  85. torch.int64: ("DT_INT64", "int64_val", _tensor_to_list),
  86. torch.complex32: ("DT_COMPLEX32", "scomplex_val", _tensor_to_complex_val),
  87. torch.chalf: ("DT_COMPLEX32", "scomplex_val", _tensor_to_complex_val),
  88. torch.complex64: ("DT_COMPLEX64", "scomplex_val", _tensor_to_complex_val),
  89. torch.cfloat: ("DT_COMPLEX64", "scomplex_val", _tensor_to_complex_val),
  90. torch.bool: ("DT_BOOL", "bool_val", _tensor_to_list),
  91. torch.complex128: ("DT_COMPLEX128", "dcomplex_val", _tensor_to_complex_val),
  92. torch.cdouble: ("DT_COMPLEX128", "dcomplex_val", _tensor_to_complex_val),
  93. torch.uint8: ("DT_UINT8", "uint32_val", _tensor_to_list),
  94. torch.quint8: ("DT_UINT8", "uint32_val", _tensor_to_list),
  95. torch.quint4x2: ("DT_UINT8", "uint32_val", _tensor_to_list),
  96. }
  97. def _calc_scale_factor(tensor):
  98. converted = tensor.numpy() if not isinstance(tensor, np.ndarray) else tensor
  99. return 1 if converted.dtype == np.uint8 else 255
  100. def _draw_single_box(
  101. image,
  102. xmin,
  103. ymin,
  104. xmax,
  105. ymax,
  106. display_str,
  107. color="black",
  108. color_text="black",
  109. thickness=2,
  110. ):
  111. from PIL import ImageDraw, ImageFont
  112. font = ImageFont.load_default()
  113. draw = ImageDraw.Draw(image)
  114. (left, right, top, bottom) = (xmin, xmax, ymin, ymax)
  115. draw.line(
  116. [(left, top), (left, bottom), (right, bottom), (right, top), (left, top)],
  117. width=thickness,
  118. fill=color,
  119. )
  120. if display_str:
  121. text_bottom = bottom
  122. # Reverse list and print from bottom to top.
  123. _left, _top, _right, _bottom = font.getbbox(display_str)
  124. text_width, text_height = _right - _left, _bottom - _top
  125. margin = np.ceil(0.05 * text_height)
  126. draw.rectangle(
  127. [
  128. (left, text_bottom - text_height - 2 * margin),
  129. (left + text_width, text_bottom),
  130. ],
  131. fill=color,
  132. )
  133. draw.text(
  134. (left + margin, text_bottom - text_height - margin),
  135. display_str,
  136. fill=color_text,
  137. font=font,
  138. )
  139. return image
  140. def hparams(hparam_dict=None, metric_dict=None, hparam_domain_discrete=None):
  141. """Output three `Summary` protocol buffers needed by hparams plugin.
  142. `Experiment` keeps the metadata of an experiment, such as the name of the
  143. hyperparameters and the name of the metrics.
  144. `SessionStartInfo` keeps key-value pairs of the hyperparameters
  145. `SessionEndInfo` describes status of the experiment e.g. STATUS_SUCCESS
  146. Args:
  147. hparam_dict: A dictionary that contains names of the hyperparameters
  148. and their values.
  149. metric_dict: A dictionary that contains names of the metrics
  150. and their values.
  151. hparam_domain_discrete: (Optional[Dict[str, List[Any]]]) A dictionary that
  152. contains names of the hyperparameters and all discrete values they can hold
  153. Returns:
  154. The `Summary` protobufs for Experiment, SessionStartInfo and
  155. SessionEndInfo
  156. """
  157. import torch
  158. from tensorboard.plugins.hparams.api_pb2 import (
  159. DataType,
  160. Experiment,
  161. HParamInfo,
  162. MetricInfo,
  163. MetricName,
  164. Status,
  165. )
  166. from tensorboard.plugins.hparams.metadata import (
  167. EXPERIMENT_TAG,
  168. PLUGIN_DATA_VERSION,
  169. PLUGIN_NAME,
  170. SESSION_END_INFO_TAG,
  171. SESSION_START_INFO_TAG,
  172. )
  173. from tensorboard.plugins.hparams.plugin_data_pb2 import (
  174. HParamsPluginData,
  175. SessionEndInfo,
  176. SessionStartInfo,
  177. )
  178. # TODO: expose other parameters in the future.
  179. # hp = HParamInfo(name='lr',display_name='learning rate',
  180. # type=DataType.DATA_TYPE_FLOAT64, domain_interval=Interval(min_value=10,
  181. # max_value=100))
  182. # mt = MetricInfo(name=MetricName(tag='accuracy'), display_name='accuracy',
  183. # description='', dataset_type=DatasetType.DATASET_VALIDATION)
  184. # exp = Experiment(name='123', description='456', time_created_secs=100.0,
  185. # hparam_infos=[hp], metric_infos=[mt], user='tw')
  186. if not isinstance(hparam_dict, dict):
  187. logger.warning("parameter: hparam_dict should be a dictionary, nothing logged.")
  188. raise TypeError(
  189. "parameter: hparam_dict should be a dictionary, nothing logged."
  190. )
  191. if not isinstance(metric_dict, dict):
  192. logger.warning("parameter: metric_dict should be a dictionary, nothing logged.")
  193. raise TypeError(
  194. "parameter: metric_dict should be a dictionary, nothing logged."
  195. )
  196. hparam_domain_discrete = hparam_domain_discrete or {}
  197. if not isinstance(hparam_domain_discrete, dict):
  198. raise TypeError(
  199. "parameter: hparam_domain_discrete should be a dictionary, nothing logged."
  200. )
  201. for k, v in hparam_domain_discrete.items():
  202. if (
  203. k not in hparam_dict
  204. or not isinstance(v, list)
  205. or not all(isinstance(d, type(hparam_dict[k])) for d in v)
  206. ):
  207. raise TypeError(
  208. f"parameter: hparam_domain_discrete[{k}] should be a list of same type as hparam_dict[{k}]."
  209. )
  210. hps = []
  211. ssi = SessionStartInfo()
  212. for k, v in hparam_dict.items():
  213. if v is None:
  214. continue
  215. if isinstance(v, (int, float)):
  216. ssi.hparams[k].number_value = v
  217. if k in hparam_domain_discrete:
  218. domain_discrete: Optional[struct_pb2.ListValue] = struct_pb2.ListValue(
  219. values=[
  220. struct_pb2.Value(number_value=d)
  221. for d in hparam_domain_discrete[k]
  222. ]
  223. )
  224. else:
  225. domain_discrete = None
  226. hps.append(
  227. HParamInfo(
  228. name=k,
  229. type=DataType.Value("DATA_TYPE_FLOAT64"),
  230. domain_discrete=domain_discrete,
  231. )
  232. )
  233. continue
  234. if isinstance(v, str):
  235. ssi.hparams[k].string_value = v
  236. if k in hparam_domain_discrete:
  237. domain_discrete = struct_pb2.ListValue(
  238. values=[
  239. struct_pb2.Value(string_value=d)
  240. for d in hparam_domain_discrete[k]
  241. ]
  242. )
  243. else:
  244. domain_discrete = None
  245. hps.append(
  246. HParamInfo(
  247. name=k,
  248. type=DataType.Value("DATA_TYPE_STRING"),
  249. domain_discrete=domain_discrete,
  250. )
  251. )
  252. continue
  253. if isinstance(v, bool):
  254. ssi.hparams[k].bool_value = v
  255. if k in hparam_domain_discrete:
  256. domain_discrete = struct_pb2.ListValue(
  257. values=[
  258. struct_pb2.Value(bool_value=d)
  259. for d in hparam_domain_discrete[k]
  260. ]
  261. )
  262. else:
  263. domain_discrete = None
  264. hps.append(
  265. HParamInfo(
  266. name=k,
  267. type=DataType.Value("DATA_TYPE_BOOL"),
  268. domain_discrete=domain_discrete,
  269. )
  270. )
  271. continue
  272. if isinstance(v, torch.Tensor):
  273. v = make_np(v)[0]
  274. ssi.hparams[k].number_value = v
  275. hps.append(HParamInfo(name=k, type=DataType.Value("DATA_TYPE_FLOAT64")))
  276. continue
  277. raise ValueError(
  278. "value should be one of int, float, str, bool, or torch.Tensor"
  279. )
  280. content = HParamsPluginData(session_start_info=ssi, version=PLUGIN_DATA_VERSION)
  281. smd = SummaryMetadata(
  282. plugin_data=SummaryMetadata.PluginData(
  283. plugin_name=PLUGIN_NAME, content=content.SerializeToString()
  284. )
  285. )
  286. ssi = Summary(value=[Summary.Value(tag=SESSION_START_INFO_TAG, metadata=smd)])
  287. mts = [MetricInfo(name=MetricName(tag=k)) for k in metric_dict.keys()]
  288. exp = Experiment(hparam_infos=hps, metric_infos=mts)
  289. content = HParamsPluginData(experiment=exp, version=PLUGIN_DATA_VERSION)
  290. smd = SummaryMetadata(
  291. plugin_data=SummaryMetadata.PluginData(
  292. plugin_name=PLUGIN_NAME, content=content.SerializeToString()
  293. )
  294. )
  295. exp = Summary(value=[Summary.Value(tag=EXPERIMENT_TAG, metadata=smd)])
  296. sei = SessionEndInfo(status=Status.Value("STATUS_SUCCESS"))
  297. content = HParamsPluginData(session_end_info=sei, version=PLUGIN_DATA_VERSION)
  298. smd = SummaryMetadata(
  299. plugin_data=SummaryMetadata.PluginData(
  300. plugin_name=PLUGIN_NAME, content=content.SerializeToString()
  301. )
  302. )
  303. sei = Summary(value=[Summary.Value(tag=SESSION_END_INFO_TAG, metadata=smd)])
  304. return exp, ssi, sei
  305. def scalar(name, tensor, collections=None, new_style=False, double_precision=False):
  306. """Output a `Summary` protocol buffer containing a single scalar value.
  307. The generated Summary has a Tensor.proto containing the input Tensor.
  308. Args:
  309. name: A name for the generated node. Will also serve as the series name in
  310. TensorBoard.
  311. tensor: A real numeric Tensor containing a single value.
  312. collections: Optional list of graph collections keys. The new summary op is
  313. added to these collections. Defaults to `[GraphKeys.SUMMARIES]`.
  314. new_style: Whether to use new style (tensor field) or old style (simple_value
  315. field). New style could lead to faster data loading.
  316. Returns:
  317. A scalar `Tensor` of type `string`. Which contains a `Summary` protobuf.
  318. Raises:
  319. ValueError: If tensor has the wrong shape or type.
  320. """
  321. tensor = make_np(tensor).squeeze()
  322. assert (
  323. tensor.ndim == 0
  324. ), f"Tensor should contain one element (0 dimensions). Was given size: {tensor.size} and {tensor.ndim} dimensions."
  325. # python float is double precision in numpy
  326. scalar = float(tensor)
  327. if new_style:
  328. tensor_proto = TensorProto(float_val=[scalar], dtype="DT_FLOAT")
  329. if double_precision:
  330. tensor_proto = TensorProto(double_val=[scalar], dtype="DT_DOUBLE")
  331. plugin_data = SummaryMetadata.PluginData(plugin_name="scalars")
  332. smd = SummaryMetadata(plugin_data=plugin_data)
  333. return Summary(
  334. value=[
  335. Summary.Value(
  336. tag=name,
  337. tensor=tensor_proto,
  338. metadata=smd,
  339. )
  340. ]
  341. )
  342. else:
  343. return Summary(value=[Summary.Value(tag=name, simple_value=scalar)])
  344. def tensor_proto(tag, tensor):
  345. """Outputs a `Summary` protocol buffer containing the full tensor.
  346. The generated Summary has a Tensor.proto containing the input Tensor.
  347. Args:
  348. name: A name for the generated node. Will also serve as the series name in
  349. TensorBoard.
  350. tensor: Tensor to be converted to protobuf
  351. Returns:
  352. A tensor protobuf in a `Summary` protobuf.
  353. Raises:
  354. ValueError: If tensor is too big to be converted to protobuf, or
  355. tensor data type is not supported
  356. """
  357. if tensor.numel() * tensor.itemsize >= (1 << 31):
  358. raise ValueError(
  359. "tensor is bigger than protocol buffer's hard limit of 2GB in size"
  360. )
  361. if tensor.dtype in _TENSOR_TYPE_MAP:
  362. dtype, field_name, conversion_fn = _TENSOR_TYPE_MAP[tensor.dtype]
  363. tensor_proto = TensorProto(
  364. **{
  365. "dtype": dtype,
  366. "tensor_shape": TensorShapeProto(
  367. dim=[TensorShapeProto.Dim(size=x) for x in tensor.shape]
  368. ),
  369. field_name: conversion_fn(tensor),
  370. },
  371. )
  372. else:
  373. raise ValueError(f"{tag} has unsupported tensor dtype {tensor.dtype}")
  374. plugin_data = SummaryMetadata.PluginData(plugin_name="tensor")
  375. smd = SummaryMetadata(plugin_data=plugin_data)
  376. return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor_proto)])
  377. def histogram_raw(name, min, max, num, sum, sum_squares, bucket_limits, bucket_counts):
  378. # pylint: disable=line-too-long
  379. """Output a `Summary` protocol buffer with a histogram.
  380. The generated
  381. [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto)
  382. has one summary value containing a histogram for `values`.
  383. Args:
  384. name: A name for the generated node. Will also serve as a series name in
  385. TensorBoard.
  386. min: A float or int min value
  387. max: A float or int max value
  388. num: Int number of values
  389. sum: Float or int sum of all values
  390. sum_squares: Float or int sum of squares for all values
  391. bucket_limits: A numeric `Tensor` with upper value per bucket
  392. bucket_counts: A numeric `Tensor` with number of values per bucket
  393. Returns:
  394. A scalar `Tensor` of type `string`. The serialized `Summary` protocol
  395. buffer.
  396. """
  397. hist = HistogramProto(
  398. min=min,
  399. max=max,
  400. num=num,
  401. sum=sum,
  402. sum_squares=sum_squares,
  403. bucket_limit=bucket_limits,
  404. bucket=bucket_counts,
  405. )
  406. return Summary(value=[Summary.Value(tag=name, histo=hist)])
  407. def histogram(name, values, bins, max_bins=None):
  408. # pylint: disable=line-too-long
  409. """Output a `Summary` protocol buffer with a histogram.
  410. The generated
  411. [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto)
  412. has one summary value containing a histogram for `values`.
  413. This op reports an `InvalidArgument` error if any value is not finite.
  414. Args:
  415. name: A name for the generated node. Will also serve as a series name in
  416. TensorBoard.
  417. values: A real numeric `Tensor`. Any shape. Values to use to
  418. build the histogram.
  419. Returns:
  420. A scalar `Tensor` of type `string`. The serialized `Summary` protocol
  421. buffer.
  422. """
  423. values = make_np(values)
  424. hist = make_histogram(values.astype(float), bins, max_bins)
  425. return Summary(value=[Summary.Value(tag=name, histo=hist)])
  426. def make_histogram(values, bins, max_bins=None):
  427. """Convert values into a histogram proto using logic from histogram.cc."""
  428. if values.size == 0:
  429. raise ValueError("The input has no element.")
  430. values = values.reshape(-1)
  431. counts, limits = np.histogram(values, bins=bins)
  432. num_bins = len(counts)
  433. if max_bins is not None and num_bins > max_bins:
  434. subsampling = num_bins // max_bins
  435. subsampling_remainder = num_bins % subsampling
  436. if subsampling_remainder != 0:
  437. counts = np.pad(
  438. counts,
  439. pad_width=[[0, subsampling - subsampling_remainder]],
  440. mode="constant",
  441. constant_values=0,
  442. )
  443. counts = counts.reshape(-1, subsampling).sum(axis=-1)
  444. new_limits = np.empty((counts.size + 1,), limits.dtype)
  445. new_limits[:-1] = limits[:-1:subsampling]
  446. new_limits[-1] = limits[-1]
  447. limits = new_limits
  448. # Find the first and the last bin defining the support of the histogram:
  449. cum_counts = np.cumsum(np.greater(counts, 0))
  450. start, end = np.searchsorted(cum_counts, [0, cum_counts[-1] - 1], side="right")
  451. start = int(start)
  452. end = int(end) + 1
  453. del cum_counts
  454. # TensorBoard only includes the right bin limits. To still have the leftmost limit
  455. # included, we include an empty bin left.
  456. # If start == 0, we need to add an empty one left, otherwise we can just include the bin left to the
  457. # first nonzero-count bin:
  458. counts = (
  459. counts[start - 1 : end] if start > 0 else np.concatenate([[0], counts[:end]])
  460. )
  461. limits = limits[start : end + 1]
  462. if counts.size == 0 or limits.size == 0:
  463. raise ValueError("The histogram is empty, please file a bug report.")
  464. sum_sq = values.dot(values)
  465. return HistogramProto(
  466. min=values.min(),
  467. max=values.max(),
  468. num=len(values),
  469. sum=values.sum(),
  470. sum_squares=sum_sq,
  471. bucket_limit=limits.tolist(),
  472. bucket=counts.tolist(),
  473. )
  474. def image(tag, tensor, rescale=1, dataformats="NCHW"):
  475. """Output a `Summary` protocol buffer with images.
  476. The summary has up to `max_images` summary values containing images. The
  477. images are built from `tensor` which must be 3-D with shape `[height, width,
  478. channels]` and where `channels` can be:
  479. * 1: `tensor` is interpreted as Grayscale.
  480. * 3: `tensor` is interpreted as RGB.
  481. * 4: `tensor` is interpreted as RGBA.
  482. The `name` in the outputted Summary.Value protobufs is generated based on the
  483. name, with a suffix depending on the max_outputs setting:
  484. * If `max_outputs` is 1, the summary value tag is '*name*/image'.
  485. * If `max_outputs` is greater than 1, the summary value tags are
  486. generated sequentially as '*name*/image/0', '*name*/image/1', etc.
  487. Args:
  488. tag: A name for the generated node. Will also serve as a series name in
  489. TensorBoard.
  490. tensor: A 3-D `uint8` or `float32` `Tensor` of shape `[height, width,
  491. channels]` where `channels` is 1, 3, or 4.
  492. 'tensor' can either have values in [0, 1] (float32) or [0, 255] (uint8).
  493. The image() function will scale the image values to [0, 255] by applying
  494. a scale factor of either 1 (uint8) or 255 (float32). Out-of-range values
  495. will be clipped.
  496. Returns:
  497. A scalar `Tensor` of type `string`. The serialized `Summary` protocol
  498. buffer.
  499. """
  500. tensor = make_np(tensor)
  501. tensor = convert_to_HWC(tensor, dataformats)
  502. # Do not assume that user passes in values in [0, 255], use data type to detect
  503. scale_factor = _calc_scale_factor(tensor)
  504. tensor = tensor.astype(np.float32)
  505. tensor = (tensor * scale_factor).clip(0, 255).astype(np.uint8)
  506. image = make_image(tensor, rescale=rescale)
  507. return Summary(value=[Summary.Value(tag=tag, image=image)])
  508. def image_boxes(
  509. tag, tensor_image, tensor_boxes, rescale=1, dataformats="CHW", labels=None
  510. ):
  511. """Output a `Summary` protocol buffer with images."""
  512. tensor_image = make_np(tensor_image)
  513. tensor_image = convert_to_HWC(tensor_image, dataformats)
  514. tensor_boxes = make_np(tensor_boxes)
  515. tensor_image = tensor_image.astype(np.float32) * _calc_scale_factor(tensor_image)
  516. image = make_image(
  517. tensor_image.clip(0, 255).astype(np.uint8),
  518. rescale=rescale,
  519. rois=tensor_boxes,
  520. labels=labels,
  521. )
  522. return Summary(value=[Summary.Value(tag=tag, image=image)])
  523. def draw_boxes(disp_image, boxes, labels=None):
  524. # xyxy format
  525. num_boxes = boxes.shape[0]
  526. list_gt = range(num_boxes)
  527. for i in list_gt:
  528. disp_image = _draw_single_box(
  529. disp_image,
  530. boxes[i, 0],
  531. boxes[i, 1],
  532. boxes[i, 2],
  533. boxes[i, 3],
  534. display_str=None if labels is None else labels[i],
  535. color="Red",
  536. )
  537. return disp_image
  538. def make_image(tensor, rescale=1, rois=None, labels=None):
  539. """Convert a numpy representation of an image to Image protobuf."""
  540. from PIL import Image
  541. height, width, channel = tensor.shape
  542. scaled_height = int(height * rescale)
  543. scaled_width = int(width * rescale)
  544. image = Image.fromarray(tensor)
  545. if rois is not None:
  546. image = draw_boxes(image, rois, labels=labels)
  547. ANTIALIAS = Image.Resampling.LANCZOS
  548. image = image.resize((scaled_width, scaled_height), ANTIALIAS)
  549. import io
  550. output = io.BytesIO()
  551. image.save(output, format="PNG")
  552. image_string = output.getvalue()
  553. output.close()
  554. return Summary.Image(
  555. height=height,
  556. width=width,
  557. colorspace=channel,
  558. encoded_image_string=image_string,
  559. )
  560. def video(tag, tensor, fps=4):
  561. tensor = make_np(tensor)
  562. tensor = _prepare_video(tensor)
  563. # If user passes in uint8, then we don't need to rescale by 255
  564. scale_factor = _calc_scale_factor(tensor)
  565. tensor = tensor.astype(np.float32)
  566. tensor = (tensor * scale_factor).clip(0, 255).astype(np.uint8)
  567. video = make_video(tensor, fps)
  568. return Summary(value=[Summary.Value(tag=tag, image=video)])
  569. def make_video(tensor, fps):
  570. try:
  571. import moviepy # noqa: F401
  572. except ImportError:
  573. print("add_video needs package moviepy")
  574. return
  575. try:
  576. from moviepy import editor as mpy
  577. except ImportError:
  578. print(
  579. "moviepy is installed, but can't import moviepy.editor.",
  580. "Some packages could be missing [imageio, requests]",
  581. )
  582. return
  583. import tempfile
  584. t, h, w, c = tensor.shape
  585. # encode sequence of images into gif string
  586. clip = mpy.ImageSequenceClip(list(tensor), fps=fps)
  587. filename = tempfile.NamedTemporaryFile(suffix=".gif", delete=False).name
  588. try: # newer version of moviepy use logger instead of progress_bar argument.
  589. clip.write_gif(filename, verbose=False, logger=None)
  590. except TypeError:
  591. try: # older version of moviepy does not support progress_bar argument.
  592. clip.write_gif(filename, verbose=False, progress_bar=False)
  593. except TypeError:
  594. clip.write_gif(filename, verbose=False)
  595. with open(filename, "rb") as f:
  596. tensor_string = f.read()
  597. try:
  598. os.remove(filename)
  599. except OSError:
  600. logger.warning("The temporary file used by moviepy cannot be deleted.")
  601. return Summary.Image(
  602. height=h, width=w, colorspace=c, encoded_image_string=tensor_string
  603. )
  604. def audio(tag, tensor, sample_rate=44100):
  605. array = make_np(tensor)
  606. array = array.squeeze()
  607. if abs(array).max() > 1:
  608. print("warning: audio amplitude out of range, auto clipped.")
  609. array = array.clip(-1, 1)
  610. assert array.ndim == 1, "input tensor should be 1 dimensional."
  611. array = (array * np.iinfo(np.int16).max).astype("<i2")
  612. import io
  613. import wave
  614. fio = io.BytesIO()
  615. with wave.open(fio, "wb") as wave_write:
  616. wave_write.setnchannels(1)
  617. wave_write.setsampwidth(2)
  618. wave_write.setframerate(sample_rate)
  619. wave_write.writeframes(array.data)
  620. audio_string = fio.getvalue()
  621. fio.close()
  622. audio = Summary.Audio(
  623. sample_rate=sample_rate,
  624. num_channels=1,
  625. length_frames=array.shape[-1],
  626. encoded_audio_string=audio_string,
  627. content_type="audio/wav",
  628. )
  629. return Summary(value=[Summary.Value(tag=tag, audio=audio)])
  630. def custom_scalars(layout):
  631. categories = []
  632. for k, v in layout.items():
  633. charts = []
  634. for chart_name, chart_meatadata in v.items():
  635. tags = chart_meatadata[1]
  636. if chart_meatadata[0] == "Margin":
  637. assert len(tags) == 3
  638. mgcc = layout_pb2.MarginChartContent(
  639. series=[
  640. layout_pb2.MarginChartContent.Series(
  641. value=tags[0], lower=tags[1], upper=tags[2]
  642. )
  643. ]
  644. )
  645. chart = layout_pb2.Chart(title=chart_name, margin=mgcc)
  646. else:
  647. mlcc = layout_pb2.MultilineChartContent(tag=tags)
  648. chart = layout_pb2.Chart(title=chart_name, multiline=mlcc)
  649. charts.append(chart)
  650. categories.append(layout_pb2.Category(title=k, chart=charts))
  651. layout = layout_pb2.Layout(category=categories)
  652. plugin_data = SummaryMetadata.PluginData(plugin_name="custom_scalars")
  653. smd = SummaryMetadata(plugin_data=plugin_data)
  654. tensor = TensorProto(
  655. dtype="DT_STRING",
  656. string_val=[layout.SerializeToString()],
  657. tensor_shape=TensorShapeProto(),
  658. )
  659. return Summary(
  660. value=[
  661. Summary.Value(tag="custom_scalars__config__", tensor=tensor, metadata=smd)
  662. ]
  663. )
  664. def text(tag, text):
  665. plugin_data = SummaryMetadata.PluginData(
  666. plugin_name="text", content=TextPluginData(version=0).SerializeToString()
  667. )
  668. smd = SummaryMetadata(plugin_data=plugin_data)
  669. tensor = TensorProto(
  670. dtype="DT_STRING",
  671. string_val=[text.encode(encoding="utf_8")],
  672. tensor_shape=TensorShapeProto(dim=[TensorShapeProto.Dim(size=1)]),
  673. )
  674. return Summary(
  675. value=[Summary.Value(tag=tag + "/text_summary", metadata=smd, tensor=tensor)]
  676. )
  677. def pr_curve_raw(
  678. tag, tp, fp, tn, fn, precision, recall, num_thresholds=127, weights=None
  679. ):
  680. if num_thresholds > 127: # weird, value > 127 breaks protobuf
  681. num_thresholds = 127
  682. data = np.stack((tp, fp, tn, fn, precision, recall))
  683. pr_curve_plugin_data = PrCurvePluginData(
  684. version=0, num_thresholds=num_thresholds
  685. ).SerializeToString()
  686. plugin_data = SummaryMetadata.PluginData(
  687. plugin_name="pr_curves", content=pr_curve_plugin_data
  688. )
  689. smd = SummaryMetadata(plugin_data=plugin_data)
  690. tensor = TensorProto(
  691. dtype="DT_FLOAT",
  692. float_val=data.reshape(-1).tolist(),
  693. tensor_shape=TensorShapeProto(
  694. dim=[
  695. TensorShapeProto.Dim(size=data.shape[0]),
  696. TensorShapeProto.Dim(size=data.shape[1]),
  697. ]
  698. ),
  699. )
  700. return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor)])
  701. def pr_curve(tag, labels, predictions, num_thresholds=127, weights=None):
  702. # weird, value > 127 breaks protobuf
  703. num_thresholds = min(num_thresholds, 127)
  704. data = compute_curve(
  705. labels, predictions, num_thresholds=num_thresholds, weights=weights
  706. )
  707. pr_curve_plugin_data = PrCurvePluginData(
  708. version=0, num_thresholds=num_thresholds
  709. ).SerializeToString()
  710. plugin_data = SummaryMetadata.PluginData(
  711. plugin_name="pr_curves", content=pr_curve_plugin_data
  712. )
  713. smd = SummaryMetadata(plugin_data=plugin_data)
  714. tensor = TensorProto(
  715. dtype="DT_FLOAT",
  716. float_val=data.reshape(-1).tolist(),
  717. tensor_shape=TensorShapeProto(
  718. dim=[
  719. TensorShapeProto.Dim(size=data.shape[0]),
  720. TensorShapeProto.Dim(size=data.shape[1]),
  721. ]
  722. ),
  723. )
  724. return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor)])
  725. # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/pr_curve/summary.py
  726. def compute_curve(labels, predictions, num_thresholds=None, weights=None):
  727. _MINIMUM_COUNT = 1e-7
  728. if weights is None:
  729. weights = 1.0
  730. # Compute bins of true positives and false positives.
  731. bucket_indices = np.int32(np.floor(predictions * (num_thresholds - 1)))
  732. float_labels = labels.astype(np.float64)
  733. histogram_range = (0, num_thresholds - 1)
  734. tp_buckets, _ = np.histogram(
  735. bucket_indices,
  736. bins=num_thresholds,
  737. range=histogram_range,
  738. weights=float_labels * weights,
  739. )
  740. fp_buckets, _ = np.histogram(
  741. bucket_indices,
  742. bins=num_thresholds,
  743. range=histogram_range,
  744. weights=(1.0 - float_labels) * weights,
  745. )
  746. # Obtain the reverse cumulative sum.
  747. tp = np.cumsum(tp_buckets[::-1])[::-1]
  748. fp = np.cumsum(fp_buckets[::-1])[::-1]
  749. tn = fp[0] - fp
  750. fn = tp[0] - tp
  751. precision = tp / np.maximum(_MINIMUM_COUNT, tp + fp)
  752. recall = tp / np.maximum(_MINIMUM_COUNT, tp + fn)
  753. return np.stack((tp, fp, tn, fn, precision, recall))
  754. def _get_tensor_summary(
  755. name, display_name, description, tensor, content_type, components, json_config
  756. ):
  757. """Create a tensor summary with summary metadata.
  758. Args:
  759. name: Uniquely identifiable name of the summary op. Could be replaced by
  760. combination of name and type to make it unique even outside of this
  761. summary.
  762. display_name: Will be used as the display name in TensorBoard.
  763. Defaults to `name`.
  764. description: A longform readable description of the summary data. Markdown
  765. is supported.
  766. tensor: Tensor to display in summary.
  767. content_type: Type of content inside the Tensor.
  768. components: Bitmask representing present parts (vertices, colors, etc.) that
  769. belong to the summary.
  770. json_config: A string, JSON-serialized dictionary of ThreeJS classes
  771. configuration.
  772. Returns:
  773. Tensor summary with metadata.
  774. """
  775. import torch
  776. from tensorboard.plugins.mesh import metadata
  777. tensor = torch.as_tensor(tensor)
  778. tensor_metadata = metadata.create_summary_metadata(
  779. name,
  780. display_name,
  781. content_type,
  782. components,
  783. tensor.shape,
  784. description,
  785. json_config=json_config,
  786. )
  787. tensor = TensorProto(
  788. dtype="DT_FLOAT",
  789. float_val=tensor.reshape(-1).tolist(),
  790. tensor_shape=TensorShapeProto(
  791. dim=[
  792. TensorShapeProto.Dim(size=tensor.shape[0]),
  793. TensorShapeProto.Dim(size=tensor.shape[1]),
  794. TensorShapeProto.Dim(size=tensor.shape[2]),
  795. ]
  796. ),
  797. )
  798. tensor_summary = Summary.Value(
  799. tag=metadata.get_instance_name(name, content_type),
  800. tensor=tensor,
  801. metadata=tensor_metadata,
  802. )
  803. return tensor_summary
  804. def _get_json_config(config_dict):
  805. """Parse and returns JSON string from python dictionary."""
  806. json_config = "{}"
  807. if config_dict is not None:
  808. json_config = json.dumps(config_dict, sort_keys=True)
  809. return json_config
  810. # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/mesh/summary.py
  811. def mesh(
  812. tag, vertices, colors, faces, config_dict, display_name=None, description=None
  813. ):
  814. """Output a merged `Summary` protocol buffer with a mesh/point cloud.
  815. Args:
  816. tag: A name for this summary operation.
  817. vertices: Tensor of shape `[dim_1, ..., dim_n, 3]` representing the 3D
  818. coordinates of vertices.
  819. faces: Tensor of shape `[dim_1, ..., dim_n, 3]` containing indices of
  820. vertices within each triangle.
  821. colors: Tensor of shape `[dim_1, ..., dim_n, 3]` containing colors for each
  822. vertex.
  823. display_name: If set, will be used as the display name in TensorBoard.
  824. Defaults to `name`.
  825. description: A longform readable description of the summary data. Markdown
  826. is supported.
  827. config_dict: Dictionary with ThreeJS classes names and configuration.
  828. Returns:
  829. Merged summary for mesh/point cloud representation.
  830. """
  831. from tensorboard.plugins.mesh import metadata
  832. from tensorboard.plugins.mesh.plugin_data_pb2 import MeshPluginData
  833. json_config = _get_json_config(config_dict)
  834. summaries = []
  835. tensors = [
  836. (vertices, MeshPluginData.VERTEX),
  837. (faces, MeshPluginData.FACE),
  838. (colors, MeshPluginData.COLOR),
  839. ]
  840. tensors = [tensor for tensor in tensors if tensor[0] is not None]
  841. components = metadata.get_components_bitmask(
  842. [content_type for (tensor, content_type) in tensors]
  843. )
  844. for tensor, content_type in tensors:
  845. summaries.append(
  846. _get_tensor_summary(
  847. tag,
  848. display_name,
  849. description,
  850. tensor,
  851. content_type,
  852. components,
  853. json_config,
  854. )
  855. )
  856. return Summary(value=summaries)