detector.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. import os
  2. import khandy
  3. import numpy as np
  4. from .base import OnnxModel
  5. from .base import check_image_dtype_and_shape
  6. class InsectDetector(OnnxModel):
  7. def __init__(self):
  8. current_dir = os.path.dirname(os.path.abspath(__file__))
  9. model_path = os.path.join(current_dir, 'models/quarrying_insect_detector.onnx')
  10. self.input_width = 640
  11. self.input_height = 640
  12. super(InsectDetector, self).__init__(model_path)
  13. def _preprocess(self, image):
  14. check_image_dtype_and_shape(image)
  15. # image size normalization
  16. image, scale, pad_left, pad_top = khandy.letterbox_image(
  17. image, self.input_width, self.input_height, 0, return_scale=True)
  18. # image channel normalization
  19. image = khandy.normalize_image_channel(image, swap_rb=True)
  20. # image dtype normalization
  21. image = khandy.rescale_image(image, 'auto', np.float32)
  22. # to tensor
  23. image = np.transpose(image, (2,0,1))
  24. image = np.expand_dims(image, axis=0)
  25. return image, scale, pad_left, pad_top
  26. def _post_process(self, outputs_list, scale, pad_left, pad_top, conf_thresh, iou_thresh):
  27. pred = outputs_list[0][0]
  28. pass_t = pred[:, 4] > conf_thresh
  29. pred = pred[pass_t]
  30. boxes = khandy.convert_boxes_format(pred[:, :4], 'cxcywh', 'xyxy')
  31. boxes = khandy.unletterbox_2d_points(boxes, scale, pad_left, pad_top, False)
  32. confs = np.max(pred[:, 5:] * pred[:, 4:5], axis=-1)
  33. classes = np.argmax(pred[:, 5:] * pred[:, 4:5], axis=-1)
  34. keep = khandy.non_max_suppression(boxes, confs, iou_thresh)
  35. return boxes[keep], confs[keep], classes[keep]
  36. def detect(self, image, conf_thresh=0.5, iou_thresh=0.5):
  37. image, scale, pad_left, pad_top = self._preprocess(image)
  38. outputs_list = self.forward(image)
  39. boxes, confs, classes = self._post_process(
  40. outputs_list,
  41. scale=scale,
  42. pad_left=pad_left,
  43. pad_top=pad_top,
  44. conf_thresh=conf_thresh,
  45. iou_thresh=iou_thresh)
  46. return boxes, confs, classes