_embedding.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. # mypy: allow-untyped-defs
  2. import math
  3. import numpy as np
  4. from ._convert_np import make_np
  5. from ._utils import make_grid
  6. from tensorboard.compat import tf
  7. from tensorboard.plugins.projector.projector_config_pb2 import EmbeddingInfo
  8. _HAS_GFILE_JOIN = hasattr(tf.io.gfile, "join")
  9. def _gfile_join(a, b):
  10. # The join API is different between tensorboard's TF stub and TF:
  11. # https://github.com/tensorflow/tensorboard/issues/6080
  12. # We need to try both because `tf` may point to either the stub or the real TF.
  13. if _HAS_GFILE_JOIN:
  14. return tf.io.gfile.join(a, b)
  15. else:
  16. fs = tf.io.gfile.get_filesystem(a)
  17. return fs.join(a, b)
  18. def make_tsv(metadata, save_path, metadata_header=None):
  19. if not metadata_header:
  20. metadata = [str(x) for x in metadata]
  21. else:
  22. assert len(metadata_header) == len(
  23. metadata[0]
  24. ), "len of header must be equal to the number of columns in metadata"
  25. metadata = ["\t".join(str(e) for e in l) for l in [metadata_header] + metadata]
  26. metadata_bytes = tf.compat.as_bytes("\n".join(metadata) + "\n")
  27. with tf.io.gfile.GFile(_gfile_join(save_path, "metadata.tsv"), "wb") as f:
  28. f.write(metadata_bytes)
  29. # https://github.com/tensorflow/tensorboard/issues/44 image label will be squared
  30. def make_sprite(label_img, save_path):
  31. from PIL import Image
  32. from io import BytesIO
  33. # this ensures the sprite image has correct dimension as described in
  34. # https://www.tensorflow.org/get_started/embedding_viz
  35. nrow = int(math.ceil((label_img.size(0)) ** 0.5))
  36. arranged_img_CHW = make_grid(make_np(label_img), ncols=nrow)
  37. # augment images so that #images equals nrow*nrow
  38. arranged_augment_square_HWC = np.zeros(
  39. (arranged_img_CHW.shape[2], arranged_img_CHW.shape[2], 3)
  40. )
  41. arranged_img_HWC = arranged_img_CHW.transpose(1, 2, 0) # chw -> hwc
  42. arranged_augment_square_HWC[: arranged_img_HWC.shape[0], :, :] = arranged_img_HWC
  43. im = Image.fromarray(np.uint8((arranged_augment_square_HWC * 255).clip(0, 255)))
  44. with BytesIO() as buf:
  45. im.save(buf, format="PNG")
  46. im_bytes = buf.getvalue()
  47. with tf.io.gfile.GFile(_gfile_join(save_path, "sprite.png"), "wb") as f:
  48. f.write(im_bytes)
  49. def get_embedding_info(metadata, label_img, subdir, global_step, tag):
  50. info = EmbeddingInfo()
  51. info.tensor_name = f"{tag}:{str(global_step).zfill(5)}"
  52. info.tensor_path = _gfile_join(subdir, "tensors.tsv")
  53. if metadata is not None:
  54. info.metadata_path = _gfile_join(subdir, "metadata.tsv")
  55. if label_img is not None:
  56. info.sprite.image_path = _gfile_join(subdir, "sprite.png")
  57. info.sprite.single_image_dim.extend([label_img.size(3), label_img.size(2)])
  58. return info
  59. def write_pbtxt(save_path, contents):
  60. config_path = _gfile_join(save_path, "projector_config.pbtxt")
  61. with tf.io.gfile.GFile(config_path, "wb") as f:
  62. f.write(tf.compat.as_bytes(contents))
  63. def make_mat(matlist, save_path):
  64. with tf.io.gfile.GFile(_gfile_join(save_path, "tensors.tsv"), "wb") as f:
  65. for x in matlist:
  66. x = [str(i.item()) for i in x]
  67. f.write(tf.compat.as_bytes("\t".join(x) + "\n"))