123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051 |
- import onnxruntime
- import numpy as np
class OnnxModel:
    """Thin wrapper around an ``onnxruntime.InferenceSession``.

    Loads an ONNX model, records its input/output names, and exposes a
    ``forward`` method that maps positional inputs onto those names.

    Args:
        model_path: Model location, passed straight to
            ``onnxruntime.InferenceSession``.
        sess_options: Optional pre-configured ``onnxruntime.SessionOptions``
            (e.g. with ``graph_optimization_level`` or
            ``intra_op_num_threads`` set). A default instance is created
            when omitted.
        providers: Optional explicit list of execution providers. When
            omitted, ``CUDAExecutionProvider`` is requested first if the
            installed onnxruntime build lists it as available, followed by
            ``CPUExecutionProvider``.
    """

    # Preferred providers in priority order; filtered against what the
    # installed onnxruntime build actually offers.
    _DEFAULT_PROVIDERS = ('CUDAExecutionProvider', 'CPUExecutionProvider')

    def __init__(self, model_path, sess_options=None, providers=None):
        if sess_options is None:
            sess_options = onnxruntime.SessionOptions()
        if providers is None:
            # Ask onnxruntime which providers are available instead of
            # inferring CUDA support from get_device(), so a provider the
            # build does not have is never requested.
            available = set(onnxruntime.get_available_providers())
            providers = [p for p in self._DEFAULT_PROVIDERS if p in available]
        self.sess = onnxruntime.InferenceSession(
            model_path, sess_options, providers=providers)
        self._input_names = [item.name for item in self.sess.get_inputs()]
        self._output_names = [item.name for item in self.sess.get_outputs()]

    @property
    def input_names(self):
        """Names of the model inputs, in session order."""
        return self._input_names

    @property
    def output_names(self):
        """Names of the model outputs, in session order."""
        return self._output_names

    def forward(self, inputs):
        """Run the model on ``inputs``.

        Args:
            inputs: Either a single input value (anything that is not a
                tuple or list) or a tuple/list of input values given in the
                same order as ``input_names``.

        Returns:
            The sole output when a single non-sequence input was given and
            the model has exactly one output; otherwise the list returned by
            ``InferenceSession.run``, ordered like ``output_names``.

        Raises:
            ValueError: If the number of inputs differs from the number of
                model inputs.
        """
        single_input = not isinstance(inputs, (tuple, list))
        if single_input:
            inputs = [inputs]
        if len(inputs) != len(self.input_names):
            # zip() would otherwise silently ignore extra inputs or leave
            # required ones unfed, surfacing as a confusing runtime error.
            raise ValueError(
                f'Expected {len(self.input_names)} input(s) '
                f'{self.input_names}, got {len(inputs)}!')
        input_feed = dict(zip(self.input_names, inputs))
        outputs = self.sess.run(self.output_names, input_feed)
        if single_input and len(self.output_names) == 1:
            return outputs[0]
        return outputs
-
def check_image_dtype_and_shape(image):
    """Validate that ``image`` is a NumPy array with a supported dtype and shape.

    Supported images are ``uint8`` or ``uint16`` arrays that are either 2-D,
    or 3-D with a last (channel) axis of size 1, 3 or 4.

    Args:
        image: The object to validate.

    Returns:
        None. The function only raises on invalid input.

    Raises:
        Exception: If ``image`` is not an ``np.ndarray``, its dtype is not
            uint8/uint16, it is not 2-D or 3-D, or (when 3-D) its last axis
            size is not 1, 3 or 4.
    """
    if not isinstance(image, np.ndarray):
        raise Exception('image is not np.ndarray!')
    # ``image.dtype`` is a ``np.dtype`` instance, so ``isinstance`` against the
    # scalar types never matches; compare by equality and reject anything else.
    if image.dtype not in (np.uint8, np.uint16):
        raise Exception(f'Unsupported image dtype, only support uint8 and uint16, got {image.dtype}!')
    if image.ndim not in {2, 3}:
        raise Exception(f'Unsupported image dimension number, only support 2 and 3, got {image.ndim}!')
    if image.ndim == 3:
        num_channels = image.shape[-1]
        if num_channels not in {1, 3, 4}:
            raise Exception(f'Unsupported image channel number, only support 1, 3 and 4, got {num_channels}!')
|