"""Fine-tune an ImageNet-pretrained MobileNetV2 on a 9-class ceramics image dataset.

Images are expected under ``DATASET_ROOT/<class folder>/<file>.jpg``. Each image's
class is the species name that appears in its path below ``DATASET_ROOT``. The data
is split 80/20 into train/test sets, the model is trained, and its ``state_dict``
is saved to ``models/mobilenet_model.pth``.
"""
import glob
import os  # used to build the glob pattern and to create the output directory

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils import data
from torchvision import models, transforms

# Use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Configure Matplotlib to render Chinese text.
plt.rcParams['font.sans-serif'] = ['SimHei']  # default to the SimHei (bold CJK) font
plt.rcParams['axes.unicode_minus'] = False  # keep '-' from rendering as a box with this font

DATASET_ROOT = '/image-recognition/dataset'
TRAIN_FRACTION = 0.8  # 80% train, 20% test
BATCH_SIZE = 10

species = ['兔毫盏', '凤纹盏', '剪纸贴花盏', '木叶天目盏', '梅瓶', '梅纹盏', '玳瑁釉盏', '鹧鸪斑盏', '黑釉盏']
species_to_id = {c: i for i, c in enumerate(species)}
id_to_species = {i: c for c, i in species_to_id.items()}
num_classes = len(species)  # derived, so it cannot drift from the species list


class Mydatasetpro(data.Dataset):
    """Map-style dataset pairing image file paths with integer class labels.

    Args:
        img_paths: indexable sequence of image file paths.
        labels: indexable sequence of class ids, aligned index-for-index with
            ``img_paths``.
        transform: callable applied to each loaded RGB ``PIL.Image``.
    """

    def __init__(self, img_paths, labels, transform):
        self.imgs = img_paths
        self.labels = labels
        self.transforms = transform

    def __getitem__(self, index):
        """Return ``(transformed_image, label)``, or ``(None, None)`` if the file cannot be read.

        Failed samples are filtered out by ``collate_skip_failed``.
        """
        img = self.imgs[index]
        label = self.labels[index]
        try:
            # The context manager closes the file handle promptly; convert()
            # returns an independent, fully loaded image.
            with Image.open(img) as raw_img:
                pil_img = raw_img.convert('RGB')
        except Exception as e:
            print(f"Error loading image {img}: {e}")
            return None, None
        return self.transforms(pil_img), label

    def __len__(self):
        return len(self.imgs)


def collate_skip_failed(batch):
    """Collate a batch after dropping samples whose image failed to load.

    ``Mydatasetpro.__getitem__`` returns ``(None, None)`` for unreadable files,
    which the default collate function cannot stack.

    Returns:
        The default-collated batch, or ``None`` if every sample in it failed.
    """
    batch = [sample for sample in batch if sample[0] is not None]
    if not batch:
        return None
    return data.dataloader.default_collate(batch)


def label_for_path(img_path, root=DATASET_ROOT):
    """Return the class id whose species name appears in ``img_path``.

    Only the part of the path below ``root`` is searched, so the dataset root
    itself can never contribute a match.

    Raises:
        ValueError: if no species name, or more than one distinct species name,
            appears in the path. Either case would otherwise misalign labels
            with images.
    """
    rel_path = os.path.relpath(img_path, root)
    matches = [i for i, name in enumerate(species) if name in rel_path]
    if len(matches) != 1:
        raise ValueError(
            f"Expected exactly one species name in {img_path!r}, "
            f"found {[species[i] for i in matches]}"
        )
    return matches[0]


# Collect all image paths. os.path.join keeps the pattern valid on both
# Windows and POSIX; sorting makes the listing order deterministic.
image_pattern = os.path.join(DATASET_ROOT, '*', '*.jpg')
all_imgs_path = sorted(glob.glob(image_pattern))
print(f"Found {len(all_imgs_path)} images.")
if not all_imgs_path:
    raise FileNotFoundError(f"No .jpg images found matching {image_pattern}")

# Exactly one label per image, so paths and labels stay aligned by construction.
all_labels = [label_for_path(img) for img in all_imgs_path]
print(f"Labels for all images: {all_labels}")

# Training-time preprocessing with random augmentation.
# NOTE(review): no transforms.Normalize with ImageNet mean/std is applied even
# though the backbone is ImageNet-pretrained. Adding it would likely help, but
# any inference code that loads the saved weights must then normalize the same way.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),  # random horizontal flip
    transforms.RandomRotation(10),  # random rotation in [-10, 10] degrees
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),  # brightness/contrast/saturation/hue jitter
    transforms.ToTensor(),
])
# Deterministic preprocessing for evaluation: same resize and tensor conversion,
# but no random augmentation, so validation metrics are comparable across epochs.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Shuffle once, then split into train/test sets.
index = np.random.permutation(len(all_imgs_path))
all_imgs_path = np.array(all_imgs_path)[index]
all_labels = np.array(all_labels)[index]

s = int(len(all_imgs_path) * TRAIN_FRACTION)
train_imgs, test_imgs = all_imgs_path[:s], all_imgs_path[s:]
train_labels, test_labels = all_labels[:s], all_labels[s:]

train_ds = Mydatasetpro(train_imgs, train_labels, transform)
test_ds = Mydatasetpro(test_imgs, test_labels, eval_transform)

train_dl = data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_skip_failed)
test_dl = data.DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_skip_failed)

# Load ImageNet-pretrained MobileNetV2 weights.
try:
    # torchvision >= 0.13: explicit weights enum (the `pretrained` flag is deprecated).
    model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
except AttributeError:
    # Older torchvision has no weights enum; fall back to the legacy flag.
    model = models.mobilenet_v2(pretrained=True)

# Replace the final classifier layer to output one logit per species.
model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)

# Move the model to the selected device (GPU/CPU).
model = model.to(device)

# Loss function and optimizer.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)


def train_model(model, train_dl, val_dl, num_epochs=10):
    """Train ``model`` for ``num_epochs`` epochs, validating after each epoch.

    Uses the module-level ``criterion``, ``optimizer`` and ``device``. Prints
    the per-epoch mean batch loss and accuracy. Batches in which every image
    failed to load (collated to ``None``) are skipped.
    """
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        num_batches = 0
        for batch in train_dl:
            if batch is None:  # every image in this batch failed to load
                continue
            inputs, labels = batch
            inputs, labels = inputs.to(device), labels.long().to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        # Guard against an empty loader, or one where every image failed to load.
        train_loss = running_loss / num_batches if num_batches else 0.0
        train_accuracy = 100 * correct / total if total else 0.0
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {train_loss:.4f}, Accuracy: {train_accuracy:.2f}%")
        validate_model(model, val_dl)


def validate_model(model, val_dl):
    """Evaluate ``model`` on ``val_dl`` and print mean batch loss and accuracy.

    Runs in eval mode without gradient tracking. An empty loader reports 0 for
    both metrics instead of dividing by zero.
    """
    model.eval()
    correct = 0
    total = 0
    val_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch in val_dl:
            if batch is None:  # every image in this batch failed to load
                continue
            inputs, labels = batch
            inputs, labels = inputs.to(device), labels.long().to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            num_batches += 1
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_loss = val_loss / num_batches if num_batches else 0.0
    val_accuracy = 100 * correct / total if total > 0 else 0
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")


# Train the model.
train_model(model, train_dl, test_dl, num_epochs=30)

# Save the trained weights, creating the output directory if needed.
model_dir = 'models'
model_path = os.path.join(model_dir, 'mobilenet_model.pth')
os.makedirs(model_dir, exist_ok=True)
torch.save(model.state_dict(), model_path)
print(f"Model saved to {model_path}")