12345678910111213141516171819202122232425262728293031 |
- import torch
- import torch.onnx
- from torchvision import models
- # 定义保存的路径
- model_path = 'models/mobilenet_model.pth'
- # 加载训练好的MobileNetV2模型
- model = models.mobilenet_v2(pretrained=False)
- # 修改最后的分类器层以适应你的数据集的类别数量(9个类)
- num_classes = 9
- model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, num_classes)
- # 加载训练好的权重
- model.load_state_dict(torch.load(model_path))
- # 切换模型为评估模式
- model.eval()
- # 创建一个示例输入(假设输入大小为 [batch_size, 3, 224, 224],这对应于一张RGB图片)
- dummy_input = torch.randn(1, 3, 224, 224)
- # 导出模型为ONNX格式
- onnx_path = "mobilenet_model.onnx"
- torch.onnx.export(
- model, # 要导出的模型
- dummy_input, # 模型的输入示例
- onnx_path, # 导出的ONNX模型保存路径
- export_params=True, # 保存训练好的参数
- opset_version=11, # ONNX opset版本(通常选择较新的版本,如11)
- do_constant_folding=True,# 是否执行常量折叠以优化模型
- input_names=['input'], # 输入节点的名称
- output_names=['output'], # 输出节点的名称
- dynamic_axes={'input': {0: 'batch_size'}, # 动态轴允许输入的batch_size可变
- 'output': {0: 'batch_size'}}
- )
- print(f"Model successfully converted to ONNX format at: {onnx_path}")
|