# Fine-tune a pretrained MobileNetV2 image classifier and save its weights.
- import glob
- import torch
- from torch.utils import data
- from PIL import Image
- import numpy as np
- from torchvision import transforms
- import matplotlib.pyplot as plt
- import torch.nn as nn
- import torch.optim as optim
- from torchvision import models
- import os # 导入os模块以检查和创建目录
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Configure Matplotlib so that Chinese text renders correctly.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],  # use SimHei (a CJK-capable font) as the default sans-serif font
    'axes.unicode_minus': False,  # avoid the minus sign '-' showing up as a box in saved figures
})
class Mydatasetpro(data.Dataset):
    """Map-style dataset that loads images from disk and pairs them with labels.

    Args:
        img_paths: Indexable collection of image file paths.
        labels: Indexable collection of labels, aligned with ``img_paths``.
        transform: Callable applied to the RGB ``PIL.Image`` of each sample.
    """

    def __init__(self, img_paths, labels, transform):
        self.imgs = img_paths
        self.labels = labels
        self.transforms = transform

    def __getitem__(self, index):
        """Return ``(transformed_image, label)`` for the sample at ``index``.

        Raises:
            IndexError: If ``index`` is out of range for the stored paths.
            RuntimeError: If the image cannot be opened or decoded. The
                message names the offending file and the original exception
                is chained as ``__cause__``.
        """
        img_path = self.imgs[index]
        label = self.labels[index]
        try:
            # The context manager releases the file handle as soon as
            # ``convert`` has produced an independent RGB copy.
            with Image.open(img_path) as raw_img:
                pil_img = raw_img.convert('RGB')
        except Exception as e:
            # Fail loudly with the file name instead of returning
            # ``(None, None)``: the default DataLoader collation cannot batch
            # ``None`` and would otherwise fail later with an unrelated error.
            raise RuntimeError(f"Error loading image {img_path}: {e}") from e
        return self.transforms(pil_img), label

    def __len__(self):
        """Return the number of samples (one per image path)."""
        return len(self.imgs)
# Collect every .jpg exactly one directory level below the dataset root.
# Forward slashes are used because they are valid separators on both Windows
# and POSIX; a backslash in a glob pattern is an ordinary character on POSIX,
# so the previous mixed-separator pattern could only match on Windows.
all_imgs_path = glob.glob('/image-recognition/dataset/*/*.jpg')
print(f"Found {len(all_imgs_path)} images.")
if not all_imgs_path:
    # Stop here with a clear message rather than failing later, deep inside
    # the data-loading / training code, on an empty dataset.
    raise FileNotFoundError("No .jpg images found under /image-recognition/dataset/*/")

# Class names; the position in this list is the integer class id.
species = ['兔毫盏', '凤纹盏', '剪纸贴花盏', '木叶天目盏', '梅瓶', '梅纹盏', '玳瑁釉盏', '鹧鸪斑盏', '黑釉盏']
species_to_id = {name: idx for idx, name in enumerate(species)}
id_to_species = {idx: name for name, idx in species_to_id.items()}

# Derive each label from the class name that occurs in the image path.
# Exactly one class must match: appending zero or several labels for a single
# path would shift every later label relative to its image.
all_labels = []
for img in all_imgs_path:
    matched_ids = [species_to_id[name] for name in species if name in img]
    if len(matched_ids) != 1:
        raise ValueError(
            f"Expected exactly one class name in path {img!r}, "
            f"found {[id_to_species[m] for m in matched_ids]}"
        )
    all_labels.append(matched_ids[0])
print(f"Labels for all images: {all_labels}")
# Input resolution shared by the training and evaluation pipelines.
IMAGE_SIZE = (224, 224)

# Training-time preprocessing: resize, random augmentation, tensor conversion.
# NOTE(review): no mean/std normalization is applied although the model is
# created from pretrained weights elsewhere in this file - confirm that this
# is intentional and matches whatever code later loads the saved weights.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),  # random horizontal flip
    transforms.RandomRotation(10),  # random rotation in [-10, 10] degrees
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),  # jitter brightness, contrast, saturation and hue
    transforms.ToTensor(),
])

# Evaluation-time preprocessing: same size and tensor conversion but no random
# augmentation, so repeated evaluations of the same image give the same input.
eval_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
])

# Mini-batch size for both loaders. The name was referenced but never defined,
# which raised NameError; 32 is a starting value to tune for the hardware.
BATCH_SIZE = 32

# Shuffle paths and labels with one shared permutation so pairs stay aligned.
shuffled_idx = np.random.permutation(len(all_imgs_path))
all_imgs_path = np.array(all_imgs_path)[shuffled_idx]
all_labels = np.array(all_labels)[shuffled_idx]

# First 80% of the shuffled samples for training, remaining 20% for testing.
split_at = int(len(all_imgs_path) * 0.8)
train_imgs, test_imgs = all_imgs_path[:split_at], all_imgs_path[split_at:]
train_labels, test_labels = all_labels[:split_at], all_labels[split_at:]

train_ds = Mydatasetpro(train_imgs, train_labels, transform)
test_ds = Mydatasetpro(test_imgs, test_labels, eval_transform)

train_dl = data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_dl = data.DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)
# One output unit per entry in `species`, so the classifier head cannot drift
# out of sync with the label list (currently nine classes).
num_classes = len(species)

# Load MobileNetV2 with pretrained weights.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
# of `weights=`; kept as-is because the installed version is not visible here.
model = models.mobilenet_v2(pretrained=True)

# Swap the final linear layer of the classifier for one sized to our classes,
# reusing the existing layer's input width.
head_in_features = model.classifier[1].in_features
model.classifier[1] = nn.Linear(head_in_features, num_classes)

# Move the model to the selected device (GPU/CPU).
model = model.to(device)

# Loss function and optimizer.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
def train_model(model, train_dl, val_dl, num_epochs=10):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    Uses the module-level ``criterion``, ``optimizer`` and ``device``. After
    every epoch the mean per-batch training loss and the training accuracy are
    printed, then ``validate_model`` is run on ``val_dl``.

    Args:
        model: Network to optimize; updated in place.
        train_dl: Iterable yielding ``(inputs, labels)`` training batches.
        val_dl: Iterable yielding ``(inputs, labels)`` validation batches.
        num_epochs: Number of passes over ``train_dl``.
    """
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        num_batches = 0
        correct = 0
        total = 0
        for inputs, labels in train_dl:
            inputs, labels = inputs.to(device), labels.long().to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        # Batches are counted while iterating, so an empty loader reports
        # zeros instead of raising ZeroDivisionError and the loader does not
        # need to support len().
        train_loss = running_loss / num_batches if num_batches else 0.0
        train_accuracy = 100 * correct / total if total else 0.0
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {train_loss:.4f}, Accuracy: {train_accuracy:.2f}%")
        validate_model(model, val_dl)
def validate_model(model, val_dl):
    """Evaluate ``model`` on ``val_dl`` and print loss and accuracy.

    Uses the module-level ``criterion`` and ``device``. The model is switched
    to evaluation mode and gradients are disabled for the whole pass.

    Args:
        model: Network to evaluate.
        val_dl: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        Tuple ``(val_loss, val_accuracy)``: the mean per-batch loss and the
        accuracy in percent. Both are ``0.0`` when ``val_dl`` yields nothing.
    """
    model.eval()
    correct = 0
    total = 0
    val_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for inputs, labels in val_dl:
            inputs, labels = inputs.to(device), labels.long().to(device)
            outputs = model(inputs)
            val_loss += criterion(outputs, labels).item()
            num_batches += 1
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Guard both divisions: dividing by the loader length used to raise
    # ZeroDivisionError on an empty loader before the accuracy guard ran.
    if num_batches:
        val_loss /= num_batches
    val_accuracy = 100 * correct / total if total > 0 else 0.0
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")
    return val_loss, val_accuracy
# Train for 30 epochs, evaluating on the held-out split after each epoch.
train_model(model, train_dl, test_dl, num_epochs=30)

# Where the trained weights are written.
model_dir = 'models'
model_path = os.path.join(model_dir, 'mobilenet_model.pth')

# Create the output directory if it is missing. exist_ok=True replaces the
# separate exists-check, which could race with another process creating it.
os.makedirs(model_dir, exist_ok=True)

# Save only the model's state_dict (its parameters and buffers).
torch.save(model.state_dict(), model_path)
print(f"Model saved to {model_path}")