identifier.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. import os
  2. import copy
  3. from collections import OrderedDict
  4. import khandy
  5. import numpy as np
  6. from .base import OnnxModel
  7. from .base import check_image_dtype_and_shape
  8. class InsectIdentifier(OnnxModel):
  9. def __init__(self):
  10. current_dir = os.path.dirname(os.path.abspath(__file__))
  11. model_path = os.path.join(current_dir, 'models/quarrying_insect_identifier.onnx')
  12. label_map_path = os.path.join(current_dir, 'models/quarrying_insectid_label_map.txt')
  13. super(InsectIdentifier, self).__init__(model_path)
  14. self.label_name_dict = self._get_label_name_dict(label_map_path)
  15. self.names = [self.label_name_dict[i]['chinese_name'] for i in range(len(self.label_name_dict))]
  16. self.num_classes = len(self.label_name_dict)
  17. @staticmethod
  18. def _get_label_name_dict(filename):
  19. records = khandy.load_list(filename)
  20. label_name_dict = {}
  21. for record in records:
  22. label, chinese_name, latin_name = record.split(',')
  23. label_name_dict[int(label)] = OrderedDict([('chinese_name', chinese_name),
  24. ('latin_name', latin_name)])
  25. return label_name_dict
  26. @staticmethod
  27. def _preprocess(image):
  28. check_image_dtype_and_shape(image)
  29. # image size normalization
  30. image = khandy.letterbox_image(image, 224, 224)
  31. # image channel normalization
  32. image = khandy.normalize_image_channel(image, swap_rb=True)
  33. # image dtype normalization
  34. # image dtype and value range normalization
  35. mean, stddev = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
  36. image = khandy.normalize_image_value(image, mean, stddev, 'auto')
  37. # to tensor
  38. image = np.transpose(image, (2,0,1))
  39. image = np.expand_dims(image, axis=0)
  40. return image
  41. def predict(self, image):
  42. inputs = self._preprocess(image)
  43. logits = self.forward(inputs)
  44. probs = khandy.softmax(logits)
  45. return probs
  46. def identify(self, image, topk=5):
  47. assert isinstance(topk, int)
  48. if topk <= 0 or topk > self.num_classes:
  49. topk = self.num_classes
  50. probs = self.predict(image)
  51. topk_probs, topk_indices = khandy.top_k(probs, topk)
  52. results = []
  53. for ind, prob in zip(topk_indices[0], topk_probs[0]):
  54. one_result = copy.deepcopy(self.label_name_dict[ind])
  55. one_result['probability'] = prob
  56. results.append(one_result)
  57. return results