_utils.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. # mypy: allow-untyped-defs
  2. import numpy as np
  3. # Functions for converting
  4. def figure_to_image(figures, close=True):
  5. """Render matplotlib figure to numpy format.
  6. Note that this requires the ``matplotlib`` package.
  7. Args:
  8. figures (matplotlib.pyplot.figure or list of figures): figure or a list of figures
  9. close (bool): Flag to automatically close the figure
  10. Returns:
  11. numpy.array: image in [CHW] order
  12. """
  13. import matplotlib.pyplot as plt
  14. import matplotlib.backends.backend_agg as plt_backend_agg
  15. def render_to_rgb(figure):
  16. canvas = plt_backend_agg.FigureCanvasAgg(figure)
  17. canvas.draw()
  18. data: np.ndarray = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8)
  19. w, h = figure.canvas.get_width_height()
  20. image_hwc = data.reshape([h, w, 4])[:, :, 0:3]
  21. image_chw = np.moveaxis(image_hwc, source=2, destination=0)
  22. if close:
  23. plt.close(figure)
  24. return image_chw
  25. if isinstance(figures, list):
  26. images = [render_to_rgb(figure) for figure in figures]
  27. return np.stack(images)
  28. else:
  29. image = render_to_rgb(figures)
  30. return image
  31. def _prepare_video(V):
  32. """
  33. Convert a 5D tensor into 4D tensor.
  34. Convesrion is done from [batchsize, time(frame), channel(color), height, width] (5D tensor)
  35. to [time(frame), new_width, new_height, channel] (4D tensor).
  36. A batch of images are spreaded to a grid, which forms a frame.
  37. e.g. Video with batchsize 16 will have a 4x4 grid.
  38. """
  39. b, t, c, h, w = V.shape
  40. if V.dtype == np.uint8:
  41. V = np.float32(V) / 255.0
  42. def is_power2(num):
  43. return num != 0 and ((num & (num - 1)) == 0)
  44. # pad to nearest power of 2, all at once
  45. if not is_power2(V.shape[0]):
  46. len_addition = int(2 ** V.shape[0].bit_length() - V.shape[0])
  47. V = np.concatenate((V, np.zeros(shape=(len_addition, t, c, h, w))), axis=0)
  48. n_rows = 2 ** ((b.bit_length() - 1) // 2)
  49. n_cols = V.shape[0] // n_rows
  50. V = np.reshape(V, newshape=(n_rows, n_cols, t, c, h, w))
  51. V = np.transpose(V, axes=(2, 0, 4, 1, 5, 3))
  52. V = np.reshape(V, newshape=(t, n_rows * h, n_cols * w, c))
  53. return V
  54. def make_grid(I, ncols=8):
  55. # I: N1HW or N3HW
  56. assert isinstance(I, np.ndarray), "plugin error, should pass numpy array here"
  57. if I.shape[1] == 1:
  58. I = np.concatenate([I, I, I], 1)
  59. assert I.ndim == 4 and I.shape[1] == 3
  60. nimg = I.shape[0]
  61. H = I.shape[2]
  62. W = I.shape[3]
  63. ncols = min(nimg, ncols)
  64. nrows = int(np.ceil(float(nimg) / ncols))
  65. canvas = np.zeros((3, H * nrows, W * ncols), dtype=I.dtype)
  66. i = 0
  67. for y in range(nrows):
  68. for x in range(ncols):
  69. if i >= nimg:
  70. break
  71. canvas[:, y * H : (y + 1) * H, x * W : (x + 1) * W] = I[i]
  72. i = i + 1
  73. return canvas
  74. # if modality == 'IMG':
  75. # if x.dtype == np.uint8:
  76. # x = x.astype(np.float32) / 255.0
  77. def convert_to_HWC(tensor, input_format): # tensor: numpy array
  78. assert len(set(input_format)) == len(
  79. input_format
  80. ), f"You can not use the same dimension shordhand twice. input_format: {input_format}"
  81. assert len(tensor.shape) == len(
  82. input_format
  83. ), f"size of input tensor and input format are different. \
  84. tensor shape: {tensor.shape}, input_format: {input_format}"
  85. input_format = input_format.upper()
  86. if len(input_format) == 4:
  87. index = [input_format.find(c) for c in "NCHW"]
  88. tensor_NCHW = tensor.transpose(index)
  89. tensor_CHW = make_grid(tensor_NCHW)
  90. return tensor_CHW.transpose(1, 2, 0)
  91. if len(input_format) == 3:
  92. index = [input_format.find(c) for c in "HWC"]
  93. tensor_HWC = tensor.transpose(index)
  94. if tensor_HWC.shape[2] == 1:
  95. tensor_HWC = np.concatenate([tensor_HWC, tensor_HWC, tensor_HWC], 2)
  96. return tensor_HWC
  97. if len(input_format) == 2:
  98. index = [input_format.find(c) for c in "HW"]
  99. tensor = tensor.transpose(index)
  100. tensor = np.stack([tensor, tensor, tensor], 2)
  101. return tensor