import torch import torch.onnx from torchvision import models # 定义保存的路径 model_path = 'models/mobilenet_model.pth' # 加载训练好的MobileNetV2模型 model = models.mobilenet_v2(pretrained=False) # 修改最后的分类器层以适应你的数据集的类别数量(9个类) num_classes = 9 model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, num_classes) # 加载训练好的权重 model.load_state_dict(torch.load(model_path)) # 切换模型为评估模式 model.eval() # 创建一个示例输入(假设输入大小为 [batch_size, 3, 224, 224],这对应于一张RGB图片) dummy_input = torch.randn(1, 3, 224, 224) # 导出模型为ONNX格式 onnx_path = "mobilenet_model.onnx" torch.onnx.export( model, # 要导出的模型 dummy_input, # 模型的输入示例 onnx_path, # 导出的ONNX模型保存路径 export_params=True, # 保存训练好的参数 opset_version=11, # ONNX opset版本(通常选择较新的版本,如11) do_constant_folding=True,# 是否执行常量折叠以优化模型 input_names=['input'], # 输入节点的名称 output_names=['output'], # 输出节点的名称 dynamic_axes={'input': {0: 'batch_size'}, # 动态轴允许输入的batch_size可变 'output': {0: 'batch_size'}} ) print(f"Model successfully converted to ONNX format at: {onnx_path}")