base.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. import onnxruntime
  2. import numpy as np
  3. class OnnxModel(object):
  4. def __init__(self, model_path):
  5. sess_options = onnxruntime.SessionOptions()
  6. # # Set graph optimization level to ORT_ENABLE_EXTENDED to enable bert optimization.
  7. # sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
  8. # # Use OpenMP optimizations. Only useful for CPU, has little impact for GPUs.
  9. # sess_options.intra_op_num_threads = multiprocessing.cpu_count()
  10. onnx_gpu = (onnxruntime.get_device() == 'GPU')
  11. providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if onnx_gpu else ['CPUExecutionProvider']
  12. self.sess = onnxruntime.InferenceSession(model_path, sess_options, providers=providers)
  13. self._input_names = [item.name for item in self.sess.get_inputs()]
  14. self._output_names = [item.name for item in self.sess.get_outputs()]
  15. @property
  16. def input_names(self):
  17. return self._input_names
  18. @property
  19. def output_names(self):
  20. return self._output_names
  21. def forward(self, inputs):
  22. to_list_flag = False
  23. if not isinstance(inputs, (tuple, list)):
  24. inputs = [inputs]
  25. to_list_flag = True
  26. input_feed = {name: input for name, input in zip(self.input_names, inputs)}
  27. outputs = self.sess.run(self.output_names, input_feed)
  28. if (len(self.output_names) == 1) and to_list_flag:
  29. return outputs[0]
  30. else:
  31. return outputs
  32. def check_image_dtype_and_shape(image):
  33. if not isinstance(image, np.ndarray):
  34. raise Exception(f'image is not np.ndarray!')
  35. if isinstance(image.dtype, (np.uint8, np.uint16)):
  36. raise Exception(f'Unsupported image dtype, only support uint8 and uint16, got {image.dtype}!')
  37. if image.ndim not in {2, 3}:
  38. raise Exception(f'Unsupported image dimension number, only support 2 and 3, got {image.ndim}!')
  39. if image.ndim == 3:
  40. num_channels = image.shape[-1]
  41. if num_channels not in {1, 3, 4}:
  42. raise Exception(f'Unsupported image channel number, only support 1, 3 and 4, got {num_channels}!')