transform_onnx.py 1.4 KB

12345678910111213141516171819202122232425262728293031
  1. import torch
  2. import torch.onnx
  3. from torchvision import models
  4. # 定义保存的路径
  5. model_path = 'models/mobilenet_model.pth'
  6. # 加载训练好的MobileNetV2模型
  7. model = models.mobilenet_v2(pretrained=False)
  8. # 修改最后的分类器层以适应你的数据集的类别数量(9个类)
  9. num_classes = 9
  10. model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, num_classes)
  11. # 加载训练好的权重
  12. model.load_state_dict(torch.load(model_path))
  13. # 切换模型为评估模式
  14. model.eval()
  15. # 创建一个示例输入(假设输入大小为 [batch_size, 3, 224, 224],这对应于一张RGB图片)
  16. dummy_input = torch.randn(1, 3, 224, 224)
  17. # 导出模型为ONNX格式
  18. onnx_path = "mobilenet_model.onnx"
  19. torch.onnx.export(
  20. model, # 要导出的模型
  21. dummy_input, # 模型的输入示例
  22. onnx_path, # 导出的ONNX模型保存路径
  23. export_params=True, # 保存训练好的参数
  24. opset_version=11, # ONNX opset版本(通常选择较新的版本,如11)
  25. do_constant_folding=True,# 是否执行常量折叠以优化模型
  26. input_names=['input'], # 输入节点的名称
  27. output_names=['output'], # 输出节点的名称
  28. dynamic_axes={'input': {0: 'batch_size'}, # 动态轴允许输入的batch_size可变
  29. 'output': {0: 'batch_size'}}
  30. )
  31. print(f"Model successfully converted to ONNX format at: {onnx_path}")