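Model Design Approach
Updated: 2025/7/18
Network Architecture Selection and Optimization
Pretrained Backbones
- EfficientNet-B4/B5: pretrained on ImageNet, parameter-efficient, well suited to fine-grained dish classification
- ResNet-152/ResNeXt-101: deep networks with strong feature extraction
- Vision Transformer (ViT): self-attention helps capture fine dish details
- ConvNeXt: a modern CNN architecture that adopts design choices from Transformers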
Multi-Scale Feature Fusion
# Example: fuse average-pooled (global) and max-pooled (salient) features
import timm
import torch
import torch.nn as nn

class MultiscaleFoodNet(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.backbone = timm.create_model('efficientnet_b4', pretrained=True)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.local_pool = nn.AdaptiveMaxPool2d(1)
        # num_features is 1792 for efficientnet_b4; the two pooled vectors are concatenated
        self.classifier = nn.Linear(self.backbone.num_features * 2, num_classes)

    def forward(self, x):
        features = self.backbone.forward_features(x)  # (B, C, H, W) feature map
        global_feat = self.global_pool(features).flatten(1)
        local_feat = self.local_pool(features).flatten(1)
        combined = torch.cat([global_feat, local_feat], dim=1)
        return self.classifier(combined)
Loss Function Design
Focal Loss
- Addresses class imbalance
- Focuses learning on hard examples
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)  # probability assigned to the true class
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        return focal_loss.mean()
Label Smoothing
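- Reduces overfitting and improves generalization
- Especially useful for visually similar dishes (for example, fish prepared in different ways)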
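To make the effect of label smoothing concrete, here is a framework-free sketch of the smoothed target distribution and the resulting cross-entropy. It assumes the uniform-mixing convention used by `nn.CrossEntropyLoss(label_smoothing=...)`, where the target becomes (1 - eps) * one_hot + eps / K. The function names are illustrative and not part of the original code.

```python
import math

def smoothed_targets(label, num_classes, eps=0.1):
    """Return (1 - eps) * one_hot(label) + eps / num_classes."""
    base = eps / num_classes
    return [base + (1.0 - eps) if k == label else base for k in range(num_classes)]

def smoothed_cross_entropy(logits, label, eps=0.1):
    """Cross-entropy between the smoothed target and softmax(logits)."""
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    log_probs = [v - log_z for v in logits]
    target = smoothed_targets(label, len(logits), eps)
    return -sum(t * lp for t, lp in zip(target, log_probs))

# With 4 classes and eps=0.1 the true class gets 0.925, every other class 0.025
print(smoothed_targets(1, 4, eps=0.1))
```

Because some probability mass is always assigned to the other classes, a very confident prediction is penalized slightly, which discourages over-confident logits.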
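Center Loss
- Pulls features of the same class toward a shared center, making features more discriminative
- Suited to telling apart dishes that differ only subtly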
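Center loss is not implemented elsewhere in this document, so the following is only a minimal sketch of the idea in plain Python: half the squared distance between each feature vector and the center of its class, averaged over the batch here. The `centers` dictionary is a hypothetical stand-in; in a real model the centers would be learnable parameters updated during training.

```python
def center_loss(features, labels, centers):
    """0.5 * mean squared distance between each feature and its class center."""
    total = 0.0
    for feat, label in zip(features, labels):
        center = centers[label]
        total += sum((f - c) ** 2 for f, c in zip(feat, center))
    return 0.5 * total / len(features)

# Two samples: the first is 1 unit away from its center, the second sits on its center
features = [[1.0, 0.0], [0.0, 1.0]]
labels = [0, 1]
centers = {0: [0.0, 0.0], 1: [0.0, 1.0]}
print(center_loss(features, labels, centers))
```

In practice this term is usually added to the classification loss with a small weight, so that it tightens clusters without overriding the class separation learned by cross-entropy.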
Data Augmentation Strategy
Geometric and Color Augmentation
from torchvision import transforms

transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomAffine(degrees=10, translate=(0.1, 0.1)),
])
Advanced Augmentation Techniques
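- CutMix/MixUp: blend pairs of images and their labels to improve generalization
- AutoAugment: searches automatically for an effective augmentation policy
- AugMax: combines randomly sampled augmentations adversarially to improve robustness (a general-purpose method, not specific to food images)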
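As an illustration of the MixUp idea listed above, the sketch below blends two samples and their one-hot labels with a coefficient drawn from a Beta(alpha, alpha) distribution. Images are represented as flat lists of floats to keep the example framework-free; `mixup_pair` and its parameters are illustrative names, not part of the original code.

```python
import random

def mixup_pair(x1, y1, x2, y2, num_classes, alpha=0.2, rng=random):
    """Blend two samples; return (mixed_x, soft_target, lam)."""
    lam = rng.betavariate(alpha, alpha)  # mixing coefficient in [0, 1]
    mixed_x = [lam * a + (1.0 - lam) * b for a, b in zip(x1, x2)]
    target = [0.0] * num_classes
    target[y1] += lam
    target[y2] += 1.0 - lam
    return mixed_x, target, lam

rng = random.Random(0)
mixed, target, lam = mixup_pair([0.0, 1.0], 0, [1.0, 0.0], 2, num_classes=3, rng=rng)
print(lam, mixed, target)
```

The soft target is then used with a loss that accepts class probabilities. CutMix follows the same labeling rule but pastes a rectangular patch instead of blending whole images.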
Integrating Attention Mechanisms
Channel Attention (SE/CBAM)
import torch.nn as nn

class SEBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.squeeze(x).view(b, c)            # one descriptor per channel
        y = self.excitation(y).view(b, c, 1, 1)   # per-channel weights in (0, 1)
        return x * y
Spatial Attention
- Focuses on the key regions of a dish (shape, color, texture)
- Suppresses background clutter
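Model Ensembling
Multi-Model Soft Voting
# Assumes efficientnet_b4, resnet152 and vit_base are trained models and
# input_image is a preprocessed batch tensor
import torch
import torch.nn.functional as F

models = [efficientnet_b4, resnet152, vit_base]
predictions = []
with torch.no_grad():
    for model in models:
        model.eval()  # disable dropout and use running BatchNorm statistics
        pred = model(input_image)
        predictions.append(F.softmax(pred, dim=1))
# Average the class probabilities across models
ensemble_pred = torch.mean(torch.stack(predictions), dim=0)
Knowledge Distillation
- A large teacher model guides the training of a smaller student model
- Speeds up inference while retaining most of the accuracy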
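The document does not give a distillation loss, so the following is a sketch of the commonly used formulation: a KL term between temperature-softened teacher and student distributions, scaled by T squared, plus the ordinary cross-entropy on the hard label. The `temperature` and `alpha` defaults are illustrative assumptions, not values from the original.

```python
import math

def softmax(logits, temperature=1.0):
    scaled = [v / temperature for v in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in scaled]
    z = sum(exps)
    return [e / z for e in exps]

def distillation_loss(student_logits, teacher_logits, label, temperature=4.0, alpha=0.7):
    """alpha * T^2 * KL(teacher || student) + (1 - alpha) * CE(student, label)."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kl = sum(pt * math.log(pt / ps) for pt, ps in zip(p_teacher, p_student) if pt > 0)
    hard = -math.log(softmax(student_logits)[label])
    return alpha * temperature ** 2 * kl + (1.0 - alpha) * hard

# The loss grows when the student disagrees with the teacher
print(distillation_loss([2.0, 0.0, 0.0], [2.0, 0.0, 0.0], label=0))
print(distillation_loss([2.0, 0.0, 0.0], [0.0, 2.0, 0.0], label=0))
```

The T squared factor keeps the gradient magnitude of the soft term comparable to the hard term as the temperature changes.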
Training Strategy Optimization
Learning Rate Scheduling
- Cosine annealing with warm restarts
- Lower the learning rate late in training for fine-tuning
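The shape of a cosine schedule with warm restarts can be sketched without any framework. The function below follows the usual formulation, in which the rate decays from `base_lr` to `min_lr` over a cycle and each cycle is `t_mult` times longer than the previous one; the default values are illustrative assumptions.

```python
import math

def cosine_warm_restart_lr(epoch, base_lr=1e-3, min_lr=1e-6, t_0=10, t_mult=2):
    """Learning rate at a given epoch for cosine annealing with warm restarts."""
    t_i, t_cur = t_0, epoch
    while t_cur >= t_i:      # find the current cycle and the position inside it
        t_cur -= t_i
        t_i *= t_mult
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * t_cur / t_i))

# Restarts happen at epochs 10 and 30 with the defaults above
for epoch in (0, 5, 9, 10, 29, 30):
    print(epoch, f"{cosine_warm_restart_lr(epoch):.6f}")
```

In PyTorch the corresponding scheduler is `torch.optim.lr_scheduler.CosineAnnealingWarmRestarts`, which takes `T_0`, `T_mult` and `eta_min`.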
Progressive Training
# Stage 1: freeze the backbone and train only the classification head
for param in model.backbone.parameters():
    param.requires_grad = False
# Stage 2: unfreeze all layers and train with a low learning rate
for param in model.parameters():
    param.requires_grad = True
optimizer = optim.Adam(model.parameters(), lr=1e-5)
Regularization
- Dropout (0.3-0.5)
- Weight decay (1e-4)
- Early stopping
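Of the regularization techniques listed, early stopping is the one that needs a little bookkeeping. A minimal sketch, assuming a validation metric where higher is better (such as accuracy); the class name and the `min_delta` parameter are illustrative:

```python
class EarlyStopping:
    """Stop training when the validation metric has not improved for `patience` epochs."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = None
        self.bad_epochs = 0

    def step(self, val_metric):
        """Record one epoch's metric; return True when training should stop."""
        if self.best is None or val_metric > self.best + self.min_delta:
            self.best = val_metric
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

stopper = EarlyStopping(patience=2)
for epoch, acc in enumerate([0.50, 0.60, 0.59, 0.60]):
    if stopper.step(acc):
        print(f"stopping after epoch {epoch}")
        break
```

A checkpoint of the best model is normally saved whenever `best` is updated, so that stopping late does not cost accuracy.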
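Dish-Specific Optimizations
Multi-Task Learning
- Predict the dish category together with its attributes (flavor, cooking method, main ingredient)
- Share the feature representation across tasks to improve generalization
Hierarchical Classification
- First classify the cuisine (Chinese, Western, Japanese)
- Then classify the specific dish within that cuisine
- Limits the cost of mistakes, since an error in the second stage still yields the correct cuisine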
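One way to combine the two stages of hierarchical classification is to multiply the cuisine probability by the dish probability within that cuisine. The sketch below assumes both stages already output probabilities; the cuisine and dish names are made up for illustration.

```python
def hierarchical_predict(coarse_probs, fine_probs_by_group):
    """Combine P(group) and P(dish | group) into P(dish) and return the best dish."""
    joint = {}
    for group, p_group in coarse_probs.items():
        for dish, p_dish in fine_probs_by_group[group].items():
            joint[dish] = p_group * p_dish
    best = max(joint, key=joint.get)
    return best, joint

coarse = {"chinese": 0.7, "japanese": 0.3}
fine = {
    "chinese": {"mapo_tofu": 0.6, "kung_pao_chicken": 0.4},
    "japanese": {"sushi": 0.9, "ramen": 0.1},
}
best, joint = hierarchical_predict(coarse, fine)
print(best, joint)
```

Using the joint probability rather than a hard first-stage decision lets a very confident dish prediction recover from a marginal cuisine error.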
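Region Feature Learning
- Use object detection as a preprocessing step to crop the main dish
- Reduces interference from tableware and background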
Development Workflow
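1. Data Preparation Stage
Data Collection
- Collect raw image data (the experiments here use the Food-101 dataset, available as food101 in TensorFlow Datasets)
- Make sure the data is diverse and representative
- Cover different lighting, angles, and backgrounds
Data Preprocessing
- Unify image sizes (typically resize to 224x224 or larger)
- Clean the data and remove low-quality images
- Check and correct label quality
Data Splitting
- Training set (70-80%)
- Validation set (10-15%)
- Test set (10-15%)
- Keep the class distribution consistent across the three splits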
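The split percentages above, with the class distribution preserved, can be sketched as a small stratified split in plain Python. The 80/10/10 ratios and the function name are illustrative assumptions; in practice `sklearn.model_selection.train_test_split` with `stratify=` does the same job.

```python
import random
from collections import defaultdict

def stratified_split(labels, ratios=(0.8, 0.1, 0.1), seed=42):
    """Return train/val/test index lists with per-class proportions preserved."""
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    rng = random.Random(seed)
    train, val, test = [], [], []
    for indices in by_class.values():
        rng.shuffle(indices)
        n_train = int(len(indices) * ratios[0])
        n_val = int(len(indices) * ratios[1])
        train += indices[:n_train]
        val += indices[n_train:n_train + n_val]
        test += indices[n_train + n_val:]   # remainder goes to the test set
    return train, val, test

labels = [0] * 100 + [1] * 50
train, val, test = stratified_split(labels)
print(len(train), len(val), len(test))
```

Splitting per class rather than over the whole list guarantees that rare dishes appear in every split.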
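2. Model Building Stage
Architecture Design
- Choose a suitable backbone (ResNet, EfficientNet, etc.)
- Design the classification head
- Add regularization components (Dropout, BatchNorm)
Hyperparameter Initialization
- Learning rate, batch size, choice of optimizer
- Loss function design
- Number of training epochs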
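3. Pretraining Stage
Transfer Learning Setup
- Load weights pretrained on ImageNet
- Freeze the backbone parameters
- Train only the classification head
Initial Training
# Freeze the pretrained layers
for param in model.backbone.parameters():
    param.requires_grad = False
# Train only the classifier
optimizer = optim.Adam(model.classifier.parameters(), lr=1e-3)
4. Fine-Tuning Stage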
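Unfreezing
- Unfreeze some or all of the pretrained layers
- Use a smaller learning rate
- Unfreeze layer by layer or in groups
Learning Rate Strategy
# Use different learning rates for different parameter groups
optimizer = optim.Adam([
    {'params': model.backbone.parameters(), 'lr': 1e-5},
    {'params': model.classifier.parameters(), 'lr': 1e-3}
])
5. Full Training Stage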
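Full Training
- Apply the complete training strategy
- Apply data augmentation
- Monitor training and validation metrics
Dynamic Adjustment
- Learning rate scheduling (StepLR, CosineAnnealingLR)
- Adjust hyperparameters based on validation performance
- Use early stopping to prevent overfitting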
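6. Validation and Tuning Stage
Hyperparameter Tuning
- Grid search or Bayesian optimization
- Cross-validation
- Model selection and comparison
Performance Analysis
- Confusion matrix analysis
- Per-class accuracy
- Analysis of misclassified samples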
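The three analyses above all start from the confusion matrix. A minimal sketch in plain Python, assuming integer class labels; the helper names are illustrative (scikit-learn's `confusion_matrix` produces the same matrix):

```python
def confusion_matrix(y_true, y_pred, num_classes):
    """matrix[t][p] counts samples of true class t predicted as class p."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix

def per_class_accuracy(matrix):
    """Fraction of each true class that was predicted correctly."""
    return [row[i] / sum(row) if sum(row) else 0.0 for i, row in enumerate(matrix)]

def most_confused_pairs(matrix, top_n=3):
    """Off-diagonal cells sorted by count: (count, true_class, predicted_class)."""
    n = len(matrix)
    pairs = [(matrix[t][p], t, p) for t in range(n) for p in range(n)
             if t != p and matrix[t][p] > 0]
    return sorted(pairs, reverse=True)[:top_n]

y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
cm = confusion_matrix(y_true, y_pred, num_classes=3)
print(cm, per_class_accuracy(cm), most_confused_pairs(cm))
```

The most confused pairs are a good starting point for error analysis, since they point directly at the dishes whose images should be inspected.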
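7. Test and Evaluation Stage
Final Evaluation
- Evaluate the model on the test set
- Compute the metrics (accuracy, F1-score, Top-5 accuracy)
- Generate a classification report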
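Top-5 accuracy and F1-score are easy to get subtly wrong, so here is a small reference sketch in plain Python. It assumes `scores` is a list of per-class score lists and uses macro averaging for F1; both function names are illustrative.

```python
def top_k_accuracy(scores, labels, k=5):
    """Fraction of samples whose true label is among the k highest-scoring classes."""
    hits = 0
    for row, label in zip(scores, labels):
        top_k = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        hits += label in top_k
    return hits / len(labels)

def macro_f1(y_true, y_pred, num_classes):
    """Unweighted mean of the per-class F1 scores."""
    f1_scores = []
    for c in range(num_classes):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        denom = 2 * tp + fp + fn
        f1_scores.append(2 * tp / denom if denom else 0.0)
    return sum(f1_scores) / num_classes

scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5]]
labels = [1, 1, 0]
print(top_k_accuracy(scores, labels, k=1), top_k_accuracy(scores, labels, k=2))
```

Macro averaging weights every dish equally, which is usually what is wanted when some classes have fewer test images than others.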
Robustness Testing
- Adversarial example testing
- Generalization to different data distributions
- Handling of edge cases
8. Deployment Optimization Stage
Model Compression
- Quantization (INT8)
- Pruning
- Knowledge distillation
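To show what INT8 quantization does to a weight tensor, the sketch below applies per-tensor affine quantization to a plain list of floats. This is only an illustration of the arithmetic under simple assumptions (one scale and zero point per tensor); real deployments would use the quantization tooling of the chosen runtime.

```python
def quantize_int8(values):
    """Affine-quantize floats to int8; return (q_values, scale, zero_point)."""
    lo = min(min(values), 0.0)   # the range must contain 0 so that 0.0 is exact
    hi = max(max(values), 0.0)
    scale = (hi - lo) / 255.0 or 1.0
    zero_point = round(-128 - lo / scale)
    q = [max(-128, min(127, round(v / scale) + zero_point)) for v in values]
    return q, scale, zero_point

def dequantize_int8(q_values, scale, zero_point):
    """Map int8 values back to approximate floats."""
    return [(q - zero_point) * scale for q in q_values]

weights = [-1.0, 0.0, 0.5, 1.0]
q, scale, zp = quantize_int8(weights)
restored = dequantize_int8(q, scale, zp)
print(q, [round(r, 4) for r in restored])
```

Each restored value differs from the original by at most about one quantization step (`scale`), which is the precision traded for a 4x smaller tensor compared with float32.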
Inference Optimization
- Model format conversion (ONNX, TensorRT)
- Batch processing optimization
- Memory usage optimization
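Key Metrics to Monitor
Training Process
- Training loss and validation loss
- Training accuracy and validation accuracy
- Learning rate over time
- Gradient norm
Overfitting Detection
- Gap between training and validation loss
- Plateau in validation accuracy
- Model complexity analysis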
Deployment and Training
pip install torch torchvision timm opencv-python pillow matplotlib scikit-learn seaborn tqdm
1. Second-Round Improved Code
# Configure the GPU
import torch

def setup_device():
    """Select the device and tune GPU usage."""
    if torch.cuda.is_available():
        device = torch.device('cuda')
        print(f'Using GPU: {torch.cuda.get_device_name(0)}')
        print(f'GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB')
        # Let cuDNN benchmark and pick the fastest kernels
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False
        # Release any cached GPU memory before training starts
        torch.cuda.empty_cache()
    else:
        device = torch.device('cpu')
        print('Using CPU')
    return device
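# =================== Optimized data loader ===================
from torch.utils.data import DataLoader

def get_optimized_dataloader(dataset, batch_size=32, shuffle=True, num_workers=4):
    """Build a DataLoader tuned for GPU training (requires num_workers > 0)."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,          # faster host-to-GPU copies
        persistent_workers=True,  # keep workers alive between epochs (needs num_workers > 0)
        prefetch_factor=2,        # batches preloaded per worker (needs num_workers > 0)
        drop_last=shuffle         # drop the last incomplete batch when training
    )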
# =================== Improved trainer class ===================
class OptimizedFoodRecognitionTrainer:
    def __init__(self, model, train_loader, val_loader, device, num_classes=10, epochs=50):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.num_classes = num_classes
        self.epochs = epochs
        # Mixed-precision training (CUDA only)
        self.scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None
        # Loss function
        self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
        # Optimizer
        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=1e-3,
            weight_decay=1e-4,
            eps=1e-8
        )
        # LR scheduler: OneCycleLR requires max_lr and the total number of steps
        self.scheduler = optim.lr_scheduler.OneCycleLR(
            self.optimizer,
            max_lr=1e-3,
            steps_per_epoch=len(train_loader),
            epochs=epochs,
            pct_start=0.3,
            div_factor=10,
            final_div_factor=100
        )
        # Training history, appended to in train()
        self.train_losses = []
        self.val_losses = []
        self.val_accuracies = []

    def train_epoch(self):
        """Train for one epoch."""
        self.model.train()
        total_loss = 0.0
        correct = 0
        total = 0
        pbar = tqdm(self.train_loader, desc='Training')
        for batch_idx, (data, target) in enumerate(pbar):
            # Non-blocking transfer to the GPU
            data = data.to(self.device, non_blocking=True)
            target = target.to(self.device, non_blocking=True)
            self.optimizer.zero_grad()
            # Mixed-precision path
            if self.scaler:
                with torch.cuda.amp.autocast():
                    output = self.model(data)
                    loss = self.criterion(output, target)
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                output = self.model(data)
                loss = self.criterion(output, target)
                loss.backward()
                self.optimizer.step()
            # OneCycleLR is stepped once per batch
            self.scheduler.step()
            total_loss += loss.item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
            pbar.set_postfix({
                'Loss': f'{loss.item():.4f}',
                'Acc': f'{100.*correct/total:.2f}%',
                'LR': f'{self.scheduler.get_last_lr()[0]:.6f}'
            })
            # Periodically release cached GPU memory
            if self.device.type == 'cuda' and batch_idx % 100 == 0:
                torch.cuda.empty_cache()
        avg_loss = total_loss / len(self.train_loader)
        accuracy = 100. * correct / total
        return avg_loss, accuracy

    def validate(self):
        """Evaluate on the validation set."""
        self.model.eval()
        total_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for data, target in tqdm(self.val_loader, desc='Validation'):
                data = data.to(self.device, non_blocking=True)
                target = target.to(self.device, non_blocking=True)
                # Mixed-precision inference
                if self.scaler:
                    with torch.cuda.amp.autocast():
                        output = self.model(data)
                        loss = self.criterion(output, target)
                else:
                    output = self.model(data)
                    loss = self.criterion(output, target)
                total_loss += loss.item()
                _, predicted = output.max(1)
                total += target.size(0)
                correct += predicted.eq(target).sum().item()
        avg_loss = total_loss / len(self.val_loader)
        accuracy = 100. * correct / total
        return avg_loss, accuracy

    def train(self, epochs=None):
        # Must not exceed the epoch count given to OneCycleLR in __init__
        epochs = self.epochs if epochs is None else epochs
        for epoch in range(epochs):
            print(f'\nEpoch {epoch+1}/{epochs}')
            print('-' * 50)
            # Train
            train_loss, train_acc = self.train_epoch()
            # Validate
            val_loss, val_acc = self.validate()
            # Record history
            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)
            self.val_accuracies.append(val_acc)
            print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
            print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
            # Report GPU memory usage
            if self.device.type == 'cuda':
                memory_allocated = torch.cuda.memory_allocated(self.device) / 1024**3
                memory_reserved = torch.cuda.memory_reserved(self.device) / 1024**3
                print(f'GPU Memory: {memory_allocated:.1f}GB allocated, {memory_reserved:.1f}GB reserved')
Main Improvements
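1. GPU optimization
- Added mixed-precision training (torch.cuda.amp)
- Use non_blocking=True to speed up data transfer
- Enabled cudnn.benchmark
- Periodically clear the GPU cache
2. Data loading optimization
- Use pin_memory=True to speed up host-to-GPU transfer
- Added persistent_workers=True to keep worker processes alive
- Set prefetch_factor=2
3. Training strategy improvements
- Use the OneCycleLR learning rate scheduler
- Added label smoothing
- Improved checkpoint saving (in the full script further down)
4. Memory management
- Monitor GPU memory usage
- Periodically clear the GPU cache
- Optimized batch processing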
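2. Parameter Tuning
Adjust batch_size to the available GPU memory (rough starting points; larger backbones or input sizes need smaller batches):
8 GB: batch_size=16-24
12 GB: batch_size=32-48
24 GB: batch_size=64+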
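The code can be tuned further for the following machine:
INFO
GPU: RTX 5090 (32 GB) x 1  CPU: 25 vCPU Intel(R) Xeon(R) Platinum 8470Q  RAM: 120 GB
Possible optimizations:
- Increase batch_size: from 32 to 128-256 or higher to use the large GPU memory (the usable value drops as the backbone and input size grow)
- Increase num_workers: from 4 to 8-16 to match the 25 vCPUs
- Increase prefetch_factor: from 2 to 4-6
- Use a larger backbone: move from 'efficientnet_b4' to 'efficientnet_b7' or another larger model (check that pretrained weights exist for the chosen timm model name)
- Increase the input size: change input_size in get_transforms from 224 to 384 or 448
- Consider torch.compile() for additional speedup (PyTorch 2.0+)
- Add torch.nn.DataParallel or torch.nn.parallel.DistributedDataParallel support; this only helps once more than one GPU is available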
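Raising batch_size, num_workers, prefetch_factor and the input size together multiplies the host memory held by queued batches, because each worker preloads `prefetch_factor` batches. The following back-of-the-envelope estimate assumes float32 image tensors and ignores labels and per-process overhead; the function name is illustrative.

```python
def prefetch_memory_gb(batch_size, num_workers, prefetch_factor, image_size,
                       channels=3, bytes_per_value=4):
    """Approximate host RAM (GiB) held by batches queued in DataLoader workers."""
    bytes_per_batch = batch_size * channels * image_size * image_size * bytes_per_value
    queued_batches = num_workers * prefetch_factor  # total batches prefetched
    return bytes_per_batch * queued_batches / 1024 ** 3

# One combination from the ranges suggested above
estimate = prefetch_memory_gb(batch_size=128, num_workers=16, prefetch_factor=4, image_size=384)
print(f"{estimate:.1f} GiB")
```

Comparing the estimate with the 120 GB of system RAM before launching a run helps avoid settings that make the data pipeline swap or get killed.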
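3. Dataset Setup
data/
├── train/
│   ├── class1/
│   └── class2/
├── val/
└── test/
Split the dataset into the three directories above. Food-101 ships an official train/test split (meta/train.txt and meta/test.txt); the script below keeps the official test set and carves a validation set out of the official training set.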
pip install scikit-learn
import os
import shutil
from pathlib import Path
from sklearn.model_selection import train_test_split

def organize_food101_dataset():
    # Define paths
    base_dir = Path("data/food-101")
    source_images_dir = base_dir / "images"
    meta_dir = base_dir / "meta"
    # Target directories
    target_base_dir = Path("data")
    train_dir = target_base_dir / "train"
    val_dir = target_base_dir / "val"
    test_dir = target_base_dir / "test"
    # Create target directories if they don't exist
    for directory in [train_dir, val_dir, test_dir]:
        os.makedirs(directory, exist_ok=True)
    # Read class names
    with open(meta_dir / "classes.txt", "r") as f:
        classes = [line.strip() for line in f.readlines()]
    print(f"Found {len(classes)} classes")
    # Read train and test image paths
    train_images = []
    with open(meta_dir / "train.txt", "r") as f:
        for line in f:
            # Format in train.txt is like: class_name/image_id
            train_images.append(line.strip())
    test_images = []
    with open(meta_dir / "test.txt", "r") as f:
        for line in f:
            test_images.append(line.strip())
    print(f"Found {len(train_images)} training images and {len(test_images)} testing images")
    # Split the official training data 80/20 into train and validation,
    # stratified by class so every dish keeps the same proportion
    train_images, val_images = train_test_split(
        train_images,
        test_size=0.2,
        random_state=42,
        stratify=[img.split('/')[0] for img in train_images]
    )
    print(f"Split into {len(train_images)} training images and {len(val_images)} validation images")

    # Copy a list of images into target_dir/class_name/
    def copy_images(image_paths, target_dir):
        for idx, img_path in enumerate(image_paths):
            class_name, img_file = img_path.split('/')
            # Create class directory if it doesn't exist
            class_dir = target_dir / class_name
            os.makedirs(class_dir, exist_ok=True)
            # Source and destination paths
            src_path = source_images_dir / f"{img_path}.jpg"
            dst_path = class_dir / f"{img_file}.jpg"
            # Copy image
            shutil.copy2(src_path, dst_path)
            # Print progress
            if idx % 1000 == 0:
                print(f"Processed {idx}/{len(image_paths)} images...")

    # Copy images to respective directories
    print("Copying training images...")
    copy_images(train_images, train_dir)
    print("Copying validation images...")
    copy_images(val_images, val_dir)
    print("Copying test images...")
    copy_images(test_images, test_dir)
    print("Dataset organization complete!")
    print(f"Train set: {len(train_images)} images")
    print(f"Validation set: {len(val_images)} images")
    print(f"Test set: {len(test_images)} images")

if __name__ == "__main__":
    organize_food101_dataset()
TIP
Change base_dir = Path("data/food-101") to the location of your own dataset.
Version 1
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torchvision.models as models
import timm
import numpy as np
import cv2
from PIL import Image
import os
import json
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
from tqdm import tqdm
import warnings
from datetime import datetime
import gc
warnings.filterwarnings('ignore')
# =================== Device configuration ===================
def setup_device():
    """Select the device and tune GPU usage."""
    if torch.cuda.is_available():
        device = torch.device('cuda')
        print(f'Using GPU: {torch.cuda.get_device_name(0)}')
        print(f'GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024 ** 3:.1f} GB')
        # Let cuDNN benchmark and pick the fastest kernels
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False
        # Allow this process to use up to 98% of GPU memory (previously 0.95)
        torch.cuda.set_per_process_memory_fraction(0.98)
        # Allocator tuning; to be safe, export this variable before the process
        # starts, since it may be ignored once CUDA has been initialized
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
        # Release cached memory
        torch.cuda.empty_cache()
        gc.collect()
    else:
        device = torch.device('cpu')
        print('Using CPU')
    return device
# =================== Dataset class ===================
class FoodDataset(Dataset):
    def __init__(self, data_dir, transform=None, is_train=True):
        """
        Dish image dataset.
        Args:
            data_dir: path to the data directory
            transform: transform applied to each image
            is_train: whether this is the training set
        """
        self.data_dir = data_dir
        self.transform = transform
        self.is_train = is_train
        # Collected samples
        self.image_paths = []
        self.labels = []
        # Expected layout: data_dir/class_name/image_files
        # Keep only directories so stray files cannot shift the label indices
        self.class_names = sorted(
            d for d in os.listdir(data_dir)
            if os.path.isdir(os.path.join(data_dir, d))
        )
        for class_idx, class_name in enumerate(self.class_names):
            class_path = os.path.join(data_dir, class_name)
            for img_file in os.listdir(class_path):
                if img_file.lower().endswith(('.png', '.jpg', '.jpeg')):
                    self.image_paths.append(os.path.join(class_path, img_file))
                    self.labels.append(class_idx)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label = self.labels[idx]
        # Load the image as RGB
        image = Image.open(image_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label
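# =================== Optimized data loader ===================
def get_optimized_dataloader(dataset, batch_size=256, shuffle=True, num_workers=20):
    """Build a DataLoader with a larger batch size (requires num_workers > 0)."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=True,
        # Each worker preloads this many batches; kept within the 4-6 range
        # recommended above, because 12 would queue num_workers * 12 batches
        # of 480x480 images in host RAM
        prefetch_factor=4,
        drop_last=shuffle
    )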
# =================== Attention modules ===================
class SEBlock(nn.Module):
    """Squeeze-and-Excitation Block"""
    def __init__(self, channels, reduction=16):
        super(SEBlock, self).__init__()
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.squeeze(x).view(b, c)
        y = self.excitation(y).view(b, c, 1, 1)
        return x * y

class SpatialAttention(nn.Module):
    """Spatial attention module"""
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x_cat = torch.cat([avg_out, max_out], dim=1)
        attention = self.sigmoid(self.conv(x_cat))
        return x * attention
# =================== Loss functions ===================
class FocalLoss(nn.Module):
    """Focal Loss for addressing class imbalance"""
    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

class LabelSmoothingLoss(nn.Module):
    """Label Smoothing Loss"""
    def __init__(self, num_classes, smoothing=0.1):
        super(LabelSmoothingLoss, self).__init__()
        self.num_classes = num_classes
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing

    def forward(self, pred, target):
        pred = F.log_softmax(pred, dim=1)
        # Spread the smoothing mass evenly over the non-target classes
        true_dist = torch.zeros_like(pred)
        true_dist.fill_(self.smoothing / (self.num_classes - 1))
        true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        return torch.mean(torch.sum(-true_dist * pred, dim=1))
# =================== Multi-scale feature fusion network ===================
class MultiscaleFoodNet(nn.Module):
    def __init__(self, num_classes, model_name='efficientnet_b7', pretrained=True):  # upgraded to B7
        super(MultiscaleFoodNet, self).__init__()
        # Backbone: a larger model, created without its classification head.
        # Pretrained weights are not published for every timm model name; if loading
        # fails for this one, try a variant such as 'tf_efficientnet_b7'.
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        # Channel count of the final feature map (no dummy forward pass needed)
        feature_dim = self.backbone.num_features
        # Attention
        self.se_block = SEBlock(feature_dim)
        self.spatial_attention = SpatialAttention()
        # Average and max pooling over the feature map
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Deeper classification head
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(feature_dim * 2, feature_dim),
            nn.BatchNorm1d(feature_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(feature_dim, feature_dim // 2),
            nn.BatchNorm1d(feature_dim // 2),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(feature_dim // 2, num_classes)
        )

    def forward(self, x):
        # Feature extraction
        features = self.backbone.forward_features(x)
        # Apply channel and spatial attention
        features = self.se_block(features)
        features = self.spatial_attention(features)
        # Pooling
        global_feat = self.global_pool(features).flatten(1)
        max_feat = self.max_pool(features).flatten(1)
        # Feature fusion
        combined_feat = torch.cat([global_feat, max_feat], dim=1)
        # Classification
        output = self.classifier(combined_feat)
        return output
# =================== Data augmentation ===================
def get_transforms(input_size=480, is_train=True):  # larger input size
    if is_train:
        return transforms.Compose([
            transforms.RandomResizedCrop(input_size, scale=(0.75, 1.0)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.2),  # vertical flip
            transforms.RandomRotation(20),
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.15),
            transforms.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1)),
            transforms.RandomGrayscale(p=0.1),  # random grayscale
            transforms.RandomPerspective(distortion_scale=0.1, p=0.3),  # perspective warp
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            transforms.RandomErasing(p=0.2, scale=(0.02, 0.2))  # random erasing (operates on tensors)
        ])
    else:
        return transforms.Compose([
            transforms.Resize((input_size, input_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
# =================== Enhanced trainer class ===================
class EnhancedFoodRecognitionTrainer:
    def __init__(self, model, train_loader, val_loader, device, num_classes=10,
                 save_dir='checkpoints', save_interval=5, accumulation_steps=1, epochs=50):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.num_classes = num_classes
        self.accumulation_steps = accumulation_steps  # gradient accumulation steps
        # Checkpoint settings
        self.save_dir = save_dir
        self.save_interval = save_interval
        # Create the checkpoint directory
        os.makedirs(save_dir, exist_ok=True)
        # Mixed-precision training (CUDA only)
        self.scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None
        # Loss functions, combined in train_epoch()
        self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
        self.focal_loss = FocalLoss(alpha=1, gamma=2)
        # Optimizer with a larger learning rate
        self.optimizer = optim.AdamW(
            model.parameters(),
            lr=2e-3,
            weight_decay=1e-4,
            eps=1e-8
        )
        # LR scheduler; `epochs` must match the value later passed to train(),
        # because OneCycleLR raises an error when stepped past its total steps
        self.scheduler = optim.lr_scheduler.OneCycleLR(
            self.optimizer,
            max_lr=2e-3,
            steps_per_epoch=len(train_loader) // accumulation_steps,
            epochs=epochs,
            pct_start=0.3,
            div_factor=10,
            final_div_factor=100
        )
        # Training history
        self.train_losses = []
        self.val_losses = []
        self.val_accuracies = []
        self.best_val_acc = 0.0
        self.best_epoch = 0
        self.start_time = datetime.now()
        # Early stopping
        self.patience = 15
        self.patience_counter = 0
        self.early_stop = False

    def save_checkpoint(self, epoch, is_best=False, is_interval=False):
        """Save a checkpoint."""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_val_acc': self.best_val_acc,
            'best_epoch': self.best_epoch,
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            'val_accuracies': self.val_accuracies,
            'scaler_state_dict': self.scaler.state_dict() if self.scaler else None,
            'training_time': str(datetime.now() - self.start_time)
        }
        if is_best:
            best_path = os.path.join(self.save_dir, 'best_model.pth')
            torch.save(checkpoint, best_path)
            print(f'✓ Best model saved at epoch {epoch} with val_acc: {self.best_val_acc:.2f}%')
        if is_interval:
            interval_path = os.path.join(self.save_dir, f'checkpoint_epoch_{epoch:03d}.pth')
            torch.save(checkpoint, interval_path)
            print(f'✓ Checkpoint saved at epoch {epoch}')
        latest_path = os.path.join(self.save_dir, 'latest_checkpoint.pth')
        torch.save(checkpoint, latest_path)
        self.save_training_log(epoch)

    def save_training_log(self, epoch):
        """Write a JSON summary of the current training state."""
        log_data = {
            'epoch': epoch,
            'best_val_acc': self.best_val_acc,
            'best_epoch': self.best_epoch,
            'current_val_acc': self.val_accuracies[-1] if self.val_accuracies else 0,
            'current_train_loss': self.train_losses[-1] if self.train_losses else 0,
            'current_val_loss': self.val_losses[-1] if self.val_losses else 0,
            'training_time': str(datetime.now() - self.start_time),
            'device': str(self.device),
            'patience_counter': self.patience_counter,
            'memory_allocated': torch.cuda.memory_allocated(self.device) / 1024 ** 3 if self.device.type == 'cuda' else 0,
            'memory_reserved': torch.cuda.memory_reserved(self.device) / 1024 ** 3 if self.device.type == 'cuda' else 0
        }
        log_path = os.path.join(self.save_dir, 'training_log.json')
        with open(log_path, 'w') as f:
            json.dump(log_data, f, indent=2)

    def load_checkpoint(self, checkpoint_path):
        """Load a checkpoint and return the epoch to resume from."""
        if os.path.exists(checkpoint_path):
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            if self.scaler and checkpoint.get('scaler_state_dict'):
                self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
            self.best_val_acc = checkpoint['best_val_acc']
            self.best_epoch = checkpoint['best_epoch']
            self.train_losses = checkpoint['train_losses']
            self.val_losses = checkpoint['val_losses']
            self.val_accuracies = checkpoint['val_accuracies']
            start_epoch = checkpoint['epoch'] + 1
            print(f'✓ Checkpoint loaded from epoch {checkpoint["epoch"]}')
            print(f'✓ Best validation accuracy: {self.best_val_acc:.2f}% at epoch {self.best_epoch}')
            return start_epoch
        else:
            print(f'✗ Checkpoint not found at {checkpoint_path}')
            return 0

    def cleanup_old_checkpoints(self, keep_last_n=3):
        """Delete all but the newest interval checkpoints."""
        checkpoint_files = [f for f in os.listdir(self.save_dir)
                            if f.startswith('checkpoint_epoch_') and f.endswith('.pth')]
        if len(checkpoint_files) > keep_last_n:
            checkpoint_files.sort()  # zero-padded epoch numbers sort chronologically
            for old_file in checkpoint_files[:-keep_last_n]:
                old_path = os.path.join(self.save_dir, old_file)
                os.remove(old_path)
                print(f'✓ Removed old checkpoint: {old_file}')

    def train_epoch(self):
        """Train for one epoch, with gradient accumulation."""
        self.model.train()
        total_loss = 0.0
        correct = 0
        total = 0
        self.optimizer.zero_grad()
        pbar = tqdm(self.train_loader, desc='Training')
        for batch_idx, (data, target) in enumerate(pbar):
            data = data.to(self.device, non_blocking=True)
            target = target.to(self.device, non_blocking=True)
            # Mixed-precision path
            if self.scaler:
                with torch.cuda.amp.autocast():
                    output = self.model(data)
                    # Combined loss
                    loss1 = self.criterion(output, target)
                    loss2 = self.focal_loss(output, target)
                    loss = 0.7 * loss1 + 0.3 * loss2
                    loss = loss / self.accumulation_steps  # scale for accumulation
                self.scaler.scale(loss).backward()
                # Step the optimizer once every accumulation_steps batches
                if (batch_idx + 1) % self.accumulation_steps == 0:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                    self.optimizer.zero_grad()
                    self.scheduler.step()
            else:
                output = self.model(data)
                loss1 = self.criterion(output, target)
                loss2 = self.focal_loss(output, target)
                loss = 0.7 * loss1 + 0.3 * loss2
                loss = loss / self.accumulation_steps
                loss.backward()
                if (batch_idx + 1) % self.accumulation_steps == 0:
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                    self.scheduler.step()
            total_loss += loss.item() * self.accumulation_steps
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
            if batch_idx % 10 == 0:  # release cached memory every 10 batches
                if self.device.type == 'cuda':
                    torch.cuda.empty_cache()
            postfix = {
                'Loss': f'{loss.item() * self.accumulation_steps:.4f}',
                'Acc': f'{100. * correct / total:.2f}%',
                'LR': f'{self.scheduler.get_last_lr()[0]:.6f}'
            }
            # Only query GPU memory when actually running on a GPU
            if self.device.type == 'cuda':
                postfix['GPU'] = f'{torch.cuda.memory_allocated(self.device) / 1024 ** 3:.1f}GB'
            pbar.set_postfix(postfix)
        avg_loss = total_loss / len(self.train_loader)
        accuracy = 100. * correct / total
        return avg_loss, accuracy

    def validate(self):
        """Evaluate on the validation set."""
        self.model.eval()
        total_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for data, target in tqdm(self.val_loader, desc='Validation'):
                data = data.to(self.device, non_blocking=True)
                target = target.to(self.device, non_blocking=True)
                if self.scaler:
                    with torch.cuda.amp.autocast():
                        output = self.model(data)
                        loss = self.criterion(output, target)
                else:
                    output = self.model(data)
                    loss = self.criterion(output, target)
                total_loss += loss.item()
                _, predicted = output.max(1)
                total += target.size(0)
                correct += predicted.eq(target).sum().item()
        avg_loss = total_loss / len(self.val_loader)
        accuracy = 100. * correct / total
        return avg_loss, accuracy

    def train(self, epochs=50, resume_from=None):
        """Run the training loop."""
        start_epoch = 0
        if resume_from:
            start_epoch = self.load_checkpoint(resume_from)
        if start_epoch >= epochs:
            # Nothing left to train; also avoids using an undefined `epoch` below
            print(f'Checkpoint already covers {start_epoch} epochs; nothing to do')
            return
        print(f"Starting training for {epochs} epochs")
        print(f"Device: {self.device}")
        print(f"Mixed precision: {'enabled' if self.scaler else 'disabled'}")
        print(f"Gradient accumulation steps: {self.accumulation_steps}")
        print(f"Effective batch size: {self.train_loader.batch_size * self.accumulation_steps}")
        print(f"Save directory: {self.save_dir}")
        for epoch in range(start_epoch, epochs):
            print(f'\nEpoch {epoch + 1}/{epochs}')
            print('-' * 60)
            # Train
            train_loss, train_acc = self.train_epoch()
            # Validate
            val_loss, val_acc = self.validate()
            # Record history
            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)
            self.val_accuracies.append(val_acc)
            print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
            print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
            # Track the best model
            is_best = val_acc > self.best_val_acc
            if is_best:
                self.best_val_acc = val_acc
                self.best_epoch = epoch
                self.patience_counter = 0
            else:
                self.patience_counter += 1
            # Save checkpoints
            is_interval = (epoch + 1) % self.save_interval == 0
            self.save_checkpoint(epoch, is_best=is_best, is_interval=is_interval)
            if is_interval:
                self.cleanup_old_checkpoints(keep_last_n=3)
            # Early stopping
            if self.patience_counter >= self.patience:
                print(f'\nEarly stopping: no improvement for {self.patience} epochs')
                self.early_stop = True
                break
            # Report GPU memory usage
            if self.device.type == 'cuda':
                memory_allocated = torch.cuda.memory_allocated(self.device) / 1024 ** 3
                memory_reserved = torch.cuda.memory_reserved(self.device) / 1024 ** 3
                print(f'GPU Memory: {memory_allocated:.1f}GB allocated, {memory_reserved:.1f}GB reserved')
            print(f'Best Val Acc: {self.best_val_acc:.2f}% (Epoch {self.best_epoch + 1})')
            print(f'Patience: {self.patience_counter}/{self.patience}')
        # Training finished
        self.save_checkpoint(epoch, is_best=False, is_interval=True)
        print('\nTraining complete!')
        print(f'Best validation accuracy: {self.best_val_acc:.2f}% (Epoch {self.best_epoch + 1})')
        print(f'Total training time: {datetime.now() - self.start_time}')

    def plot_training_history(self):
        """Plot the training history."""
        plt.figure(figsize=(15, 5))
        plt.subplot(1, 3, 1)
        plt.plot(self.train_losses, label='Train Loss')
        plt.plot(self.val_losses, label='Val Loss')
        plt.title('Training and Validation Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)
        plt.subplot(1, 3, 2)
        plt.plot(self.val_accuracies, label='Val Accuracy')
        plt.title('Validation Accuracy')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy (%)')
        plt.legend()
        plt.grid(True)
        plt.subplot(1, 3, 3)
        # The per-epoch learning rate is not recorded during training, so this
        # panel shows only the scheduler's final value as a flat line
        plt.plot(range(len(self.train_losses)),
                 [self.scheduler.get_last_lr()[0] for _ in range(len(self.train_losses))])
        plt.title('Learning Rate (final value)')
        plt.xlabel('Epoch')
        plt.ylabel('Learning Rate')
        plt.grid(True)
        plt.tight_layout()
        plt.savefig('training_history.png', dpi=300)
        plt.show()

# =================== Inference class ===================
class FoodPredictor:
    def __init__(self, model_path, class_names, device='cuda'):
        # Accept both strings and torch.device objects
        self.device = torch.device(device)
        self.class_names = class_names
        # Build the model; the weights come from the checkpoint, so skip the pretrained download
        self.model = MultiscaleFoodNet(num_classes=len(class_names), model_name='efficientnet_b7',
                                       pretrained=False)
        checkpoint = torch.load(model_path, map_location=self.device)
        if 'model_state_dict' in checkpoint:
            self.model.load_state_dict(checkpoint['model_state_dict'])
        else:
            self.model.load_state_dict(checkpoint)
        self.model.to(self.device)
        self.model.eval()
        # Preprocessing (same as validation)
        self.transform = get_transforms(input_size=480, is_train=False)

    def predict(self, image_path, top_k=3):
        """Predict the top_k classes for a single image."""
        image = Image.open(image_path).convert('RGB')
        input_tensor = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            # Mixed precision only when running on a GPU
            with torch.cuda.amp.autocast(enabled=self.device.type == 'cuda'):
                output = self.model(input_tensor)
            probabilities = F.softmax(output, dim=1)
            top_probs, top_indices = torch.topk(probabilities, top_k)
        results = []
        for i in range(top_k):
            class_name = self.class_names[top_indices[0][i].item()]
            confidence = top_probs[0][i].item()
            results.append((class_name, confidence))
        return results

    def predict_batch(self, image_paths):
        """Predict a list of images one at a time."""
        results = []
        for image_path in image_paths:
            result = self.predict(image_path)
            results.append(result)
        return results
# =================== Model evaluation ===================
def evaluate_model(model, test_loader, class_names, device='cuda'):
    device = torch.device(device)  # accept both strings and torch.device objects
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        # Mixed precision only when running on a GPU
        with torch.cuda.amp.autocast(enabled=device.type == 'cuda'):
            for data, target in tqdm(test_loader, desc='Evaluating'):
                data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
                output = model(data)
                _, predicted = output.max(1)
                all_preds.extend(predicted.cpu().numpy())
                all_labels.extend(target.cpu().numpy())
    accuracy = np.mean(np.array(all_preds) == np.array(all_labels))
    print(f'Test Accuracy: {accuracy:.4f}')
    report = classification_report(all_labels, all_preds, target_names=class_names)
    print("\nClassification Report:")
    print(report)
    cm = confusion_matrix(all_labels, all_preds)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    plt.savefig('confusion_matrix.png', dpi=300)
    plt.show()
# =================== Main ===================
def main():
    # Device configuration
    device = setup_device()
    # Dataset paths
    train_data_dir = 'data/train'
    val_data_dir = 'data/val'
    test_data_dir = 'data/test'
    # Abort if any data directory is missing
    for data_dir in [train_data_dir, val_data_dir, test_data_dir]:
        if not os.path.exists(data_dir):
            print(f"Warning: data directory {data_dir} does not exist")
            return
    # Build the datasets
    train_transform = get_transforms(input_size=480, is_train=True)  # larger input size
    val_transform = get_transforms(input_size=480, is_train=False)  # larger input size
    train_dataset = FoodDataset(train_data_dir, transform=train_transform)
    val_dataset = FoodDataset(val_data_dir, transform=val_transform)
    test_dataset = FoodDataset(test_data_dir, transform=val_transform)
    # Optimized data loaders
    train_loader = get_optimized_dataloader(train_dataset, batch_size=256, shuffle=True)  # larger batch size
    val_loader = get_optimized_dataloader(val_dataset, batch_size=256, shuffle=False)  # larger batch size
    test_loader = get_optimized_dataloader(test_dataset, batch_size=256, shuffle=False)  # larger batch size
    # Class information
    class_names = train_dataset.class_names
    num_classes = len(class_names)
    print(f'Number of classes: {num_classes}')
    print(f'Classes: {class_names}')
    print(f'Training samples: {len(train_dataset)}')
    print(f'Validation samples: {len(val_dataset)}')
    print(f'Test samples: {len(test_dataset)}')
    # Build the model
    model = MultiscaleFoodNet(num_classes=num_classes, model_name='efficientnet_b7')  # upgraded backbone
    # Count model parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'Total parameters: {total_params:,}')
    print(f'Trainable parameters: {trainable_params:,}')
    # Create the checkpoint directory
    save_dir = 'checkpoints'
    os.makedirs(save_dir, exist_ok=True)
    # Train the model
    trainer = EnhancedFoodRecognitionTrainer(
        model,
        train_loader,
        val_loader,
        device,
        num_classes,
        save_dir=save_dir,
        save_interval=5,
        accumulation_steps=1  # no gradient accumulation; rely on the large batch instead
    )
    # Resume from a checkpoint or start a fresh run
    resume_path = os.path.join(save_dir, 'latest_checkpoint.pth')
    if os.path.exists(resume_path):
        print(f"Found checkpoint: {resume_path}")
        resume_choice = input("Resume training from this checkpoint? (y/n): ").lower()
        if resume_choice == 'y':
            trainer.train(epochs=50, resume_from=resume_path)
        else:
            trainer.train(epochs=50)
    else:
        trainer.train(epochs=50)
    # Plot the training history
    trainer.plot_training_history()
    # Evaluate the best checkpoint on the test set
    best_model_path = os.path.join(save_dir, 'best_model.pth')
    checkpoint = torch.load(best_model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    evaluate_model(model, test_loader, class_names, device)
    # Inference example
    predictor = FoodPredictor(best_model_path, class_names, device)
    # Single-image prediction example
    # results = predictor.predict('test_image.jpg')
    # print("Prediction results:")
    # for class_name, confidence in results:
    #     print(f"{class_name}: {confidence:.4f}")


if __name__ == '__main__':
    main()

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image, ImageTk
import torchvision.transforms as transforms
import timm
import os
import tkinter as tk
from tkinter import filedialog, ttk, messagebox
import numpy as np
# =================== Global configuration ===================
# These parameters must match the ones used for training in pre5090.py
NUM_COLOR_CLASSES = 20
NUM_INGREDIENT_CLASSES = 101
INPUT_SIZE = 480  # must match the input size used in training
MODEL_NAME = 'convnext_large_in22k'
# Important: point the model path at your new training results
MODEL_PATH = 'second_pre/checkpoints/best_model.pth'
DATA_DIR_FOR_CLASSES = 'data/train'
# =================== English-to-Chinese class name mapping ===================
CLASS_NAME_EN_TO_ZH = {
'apple_pie': '苹果派', 'baby_back_ribs': '烤肋排', 'baklava': '果仁蜜饼',
'beef_carpaccio': '生牛肉片', 'beef_tartare': '鞑靼牛肉', 'beet_salad': '甜菜沙拉',
'beignets': '法式甜甜圈', 'bibimbap': '石锅拌饭', 'bread_pudding': '面包布丁',
'breakfast_burrito': '早餐卷饼', 'bruschetta': '意式烤面包', 'caesar_salad': '凯撒沙拉',
'cannoli': '奶油甜馅煎饼卷', 'caprese_salad': '卡布里沙拉', 'carrot_cake': '胡萝卜蛋糕',
'ceviche': '酸橘汁腌鱼', 'cheesecake': '芝士蛋糕', 'cheese_plate': '奶酪拼盘',
'chicken_curry': '咖喱鸡', 'chicken_quesadilla': '墨西哥鸡肉饼', 'chicken_wings': '鸡翅',
'chocolate_cake': '巧克力蛋糕', 'chocolate_mousse': '巧克力慕斯', 'churros': '吉事果',
'clam_chowder': '蛤蜊浓汤', 'club_sandwich': '总会三明治', 'crab_cakes': '蟹饼',
'creme_brulee': '法式焦糖布丁', 'croque_madame': '法式太太三明治', 'cup_cakes': '纸杯蛋糕',
'deviled_eggs': '魔鬼蛋', 'donuts': '甜甜圈', 'dumplings': '饺子',
'edamame': '毛豆', 'eggs_benedict': '班尼迪克蛋', 'escargots': '法式焗蜗牛',
'falafel': '中东蔬菜球', 'filet_mignon': '菲力牛排', 'fish_and_chips': '炸鱼薯条',
'foie_gras': '鹅肝', 'french_fries': '炸薯条/炸薯片', 'french_onion_soup': '法式洋葱汤',
'french_toast': '法式吐司', 'fried_calamari': '炸鱿鱼', 'fried_rice': '炒饭',
'frozen_yogurt': '冻酸奶', 'garlic_bread': '蒜香面包', 'gnocchi': '意式面疙瘩',
'greek_salad': '希腊沙拉', 'grilled_cheese_sandwich': '烤奶酪三明治', 'grilled_salmon': '烤三文鱼',
'guacamole': '鳄梨酱', 'gyoza': '日式煎饺', 'hamburger': '汉堡',
'hot_and_sour_soup': '酸辣汤', 'hot_dog': '热狗', 'huevos_rancheros': '墨西哥牧场蛋',
'hummus': '鹰嘴豆泥', 'ice_cream': '冰淇淋', 'lasagna': '意式千层面',
'lobster_bisque': '龙虾浓汤', 'lobster_roll_sandwich': '龙虾三明治', 'macaroni_and_cheese': '奶酪通心粉',
'macarons': '马卡龙', 'miso_soup': '味噌汤', 'mussels': '贻贝',
'nachos': '墨西哥玉米片', 'omelette': '煎蛋卷', 'onion_rings': '洋葱圈',
'oysters': '生蚝', 'pad_thai': '泰式炒河粉', 'paella': '西班牙海鲜饭',
'pancakes': '煎饼', 'panna_cotta': '意式奶昔布丁', 'peking_duck': '北京烤鸭',
'pho': '越南河粉', 'pizza': '披萨', 'pork_chop': '猪排',
'poutine': '肉汁奶酪薯条', 'prime_rib': '顶级肋排', 'pulled_pork_sandwich': '手撕猪肉三明治',
'ramen': '拉面', 'ravioli': '意大利方饺', 'red_velvet_cake': '红丝绒蛋糕',
'risotto': '意大利烩饭', 'samosa': '咖喱角', 'sashimi': '生鱼片',
'scallops': '扇贝', 'seaweed_salad': '海藻沙拉', 'shrimp_and_grits': '虾配玉米粥',
'spaghetti_bolognese': '意式肉酱面', 'spaghetti_carbonara': '培根蛋酱意面', 'spring_rolls': '春卷',
'steak': '牛排', 'strawberry_shortcake': '草莓蛋糕', 'sushi': '寿司',
'tacos': '塔可', 'takoyaki': '章鱼烧', 'tiramisu': '提拉米苏',
'tuna_tartare': '鞑靼金枪鱼', 'waffles': '华夫饼'
}
# =================== Model definition (copied from pre5090.py) ===================
class SEBlock(nn.Module):
    """Squeeze-and-Excitation block (channel attention)."""
    def __init__(self, channels, reduction=16):
        super(SEBlock, self).__init__()
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.squeeze(x).view(b, c)
        y = self.excitation(y).view(b, c, 1, 1)
        return x * y


class SpatialAttention(nn.Module):
    """Spatial attention module."""
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Pool across channels, then learn a per-pixel weight from both maps.
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x_cat = torch.cat([avg_out, max_out], dim=1)
        attention = self.sigmoid(self.conv(x_cat))
        return x * attention
class MultiTaskFoodNet(nn.Module):
    def __init__(self, num_classes, num_colors, num_ingredients, model_name=MODEL_NAME, pretrained=False):
        super(MultiTaskFoodNet, self).__init__()
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        # Infer the backbone's feature dimension with one dummy forward pass.
        with torch.no_grad():
            dummy_input = torch.randn(1, 3, INPUT_SIZE, INPUT_SIZE)
            features = self.backbone(dummy_input)
            feature_dim = features.size(1)
        self.se_block = SEBlock(feature_dim)
        self.spatial_attention = SpatialAttention()
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Shared trunk feeding the three task-specific heads.
        self.shared_classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(feature_dim * 2, feature_dim),
            nn.BatchNorm1d(feature_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(feature_dim, feature_dim // 2),
            nn.BatchNorm1d(feature_dim // 2),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2)
        )
        self.classifier_main = nn.Linear(feature_dim // 2, num_classes)
        self.classifier_color = nn.Linear(feature_dim // 2, num_colors)
        self.classifier_ingredient = nn.Linear(feature_dim // 2, num_ingredients)

    def forward(self, x):
        features = self.backbone.forward_features(x)
        features = self.se_block(features)
        features = self.spatial_attention(features)
        # Concatenate average- and max-pooled descriptors.
        global_feat = self.global_pool(features).flatten(1)
        max_feat = self.max_pool(features).flatten(1)
        combined_feat = torch.cat([global_feat, max_feat], dim=1)
        shared_features = self.shared_classifier(combined_feat)
        output_main = self.classifier_main(shared_features)
        output_color = self.classifier_color(shared_features)
        output_ingredient = self.classifier_ingredient(shared_features)
        return output_main, output_color, output_ingredient
# =================== Preprocessing ===================
def get_transforms(input_size=INPUT_SIZE):
    return transforms.Compose([
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
# =================== Predictor class ===================
class Predictor:
    def __init__(self, model_path, class_names, device):
        self.device = device
        self.class_names = class_names
        self.num_classes = len(class_names)
        # Build the model, then load the trained weights
        self.model = MultiTaskFoodNet(
            num_classes=self.num_classes,
            num_colors=NUM_COLOR_CLASSES,
            num_ingredients=NUM_INGREDIENT_CLASSES
        )
        try:
            checkpoint = torch.load(model_path, map_location=device)
            # Support both full training checkpoints and bare state dicts.
            if 'model_state_dict' in checkpoint:
                self.model.load_state_dict(checkpoint['model_state_dict'])
            else:
                self.model.load_state_dict(checkpoint)
            self.model.to(device)
            self.model.eval()
            print("模型加载成功!")
        except FileNotFoundError:
            messagebox.showerror("错误", f"模型文件未找到: {model_path}")
            raise
        except Exception as e:
            messagebox.showerror("错误", f"加载模型时发生错误: {e}\n\n这可能是因为模型结构与权重文件不匹配。请确保所有参数(如INPUT_SIZE)与训练时完全一致。")
            raise
        self.transform = get_transforms()
    def predict(self, image_path, top_k=3):
        try:
            image = Image.open(image_path).convert('RGB')
            input_tensor = self.transform(image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                output_main, output_color, output_ingredient = self.model(input_tensor)
            # Main task: dish classification
            probabilities_main = F.softmax(output_main, dim=1)
            top_probs_main, top_indices_main = torch.topk(probabilities_main, top_k)
            # Auxiliary task: color (single label)
            prob_color = F.softmax(output_color, dim=1)
            color_conf, color_idx = torch.max(prob_color, 1)
            # Auxiliary task: ingredients (multi-label)
            prob_ingredient = torch.sigmoid(output_ingredient)
            predicted_ingredients_indices = torch.topk(prob_ingredient, 5, dim=1).indices.squeeze(0).tolist()
            main_results = []
            for i in range(top_k):
                class_name = self.class_names[top_indices_main[0][i].item()]
                confidence = top_probs_main[0][i].item()
                main_results.append((class_name, confidence))
            # Note: the color and ingredient labels are dummy labels, so these
            # auxiliary predictions are for demonstration only.
            aux_results = {
                "predicted_color_index": color_idx.item(),
                "predicted_color_confidence": color_conf.item(),
                "predicted_ingredient_indices": predicted_ingredients_indices
            }
            return main_results, aux_results
        except Exception as e:
            messagebox.showerror("预测错误", f"预测过程中发生错误: {e}")
            return None, None
# =================== GUI application ===================
class PredictionApp:
    def __init__(self, root):
        self.root = root
        self.root.title("菜品识别预测2轮测试")
        self.root.geometry("800x650")
        style = ttk.Style()
        style.configure("TButton", font=("Helvetica", 12), padding=10)
        style.configure("TLabel", font=("Helvetica", 12))
        style.configure("Result.TLabel", font=("Helvetica", 14, "bold"))
        main_frame = ttk.Frame(root, padding="20 20 20 20")
        main_frame.pack(expand=True, fill=tk.BOTH)
        self.select_button = ttk.Button(main_frame, text="选择图片", command=self.select_image)
        self.select_button.pack(pady=10)
        self.image_label = ttk.Label(main_frame, text="请先选择一张图片")
        self.image_label.pack(pady=10)
        self.predict_button = ttk.Button(main_frame, text="开始预测", command=self.run_prediction, state=tk.DISABLED)
        self.predict_button.pack(pady=10)
        self.result_label = ttk.Label(main_frame, text="预测结果将显示在这里", style="Result.TLabel", justify=tk.LEFT, wraplength=700)
        self.result_label.pack(pady=20)
        self.image_path = None
        try:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            class_names = self.get_class_names()
            if not class_names:
                # Without class names no predictor can be built, so close the
                # window instead of leaving self.predictor undefined.
                self.root.destroy()
                return
            self.predictor = Predictor(MODEL_PATH, class_names, self.device)
        except Exception:
            self.root.destroy()
    def get_class_names(self):
        if not os.path.exists(DATA_DIR_FOR_CLASSES):
            messagebox.showerror("错误", f"无法找到类别目录: {DATA_DIR_FOR_CLASSES}")
            return None
        return sorted([d for d in os.listdir(DATA_DIR_FOR_CLASSES) if os.path.isdir(os.path.join(DATA_DIR_FOR_CLASSES, d))])

    def select_image(self):
        file_path = filedialog.askopenfilename(filetypes=[("Image Files", "*.jpg *.jpeg *.png *.bmp")])
        if file_path:
            self.image_path = file_path
            img = Image.open(self.image_path)
            img.thumbnail((400, 400))
            photo = ImageTk.PhotoImage(img)
            self.image_label.config(image=photo, text="")
            self.image_label.image = photo
            self.predict_button.config(state=tk.NORMAL)
            self.result_label.config(text="图片已选择,请点击“开始预测”")
    def run_prediction(self):
        if not self.image_path:
            messagebox.showwarning("警告", "请先选择一张图片。")
            return
        self.predict_button.config(state=tk.DISABLED)
        self.result_label.config(text="正在预测中,请稍候...")
        self.root.update_idletasks()
        main_results, aux_results = self.predictor.predict(self.image_path)
        self.predict_button.config(state=tk.NORMAL)
        if main_results:
            result_text = "--- 菜品预测 (Top 3) ---\n\n"
            for i, (class_name, confidence) in enumerate(main_results):
                # Show the Chinese name; fall back to prettified English if none exists
                class_name_zh = CLASS_NAME_EN_TO_ZH.get(class_name, class_name.replace('_', ' ').title())
                result_text += f"{i+1}. 菜品: {class_name_zh}\n 置信度: {confidence:.2%}\n\n"
            result_text += "--- 辅助信息 (仅供参考) ---\n\n"
            result_text += f"预测颜色 (索引): {aux_results['predicted_color_index']} (置信度: {aux_results['predicted_color_confidence']:.1%})\n"
            result_text += f"预测食材 (Top 5 索引): {aux_results['predicted_ingredient_indices']}\n"
            self.result_label.config(text=result_text)
        else:
            self.result_label.config(text="无法获取预测结果,请检查控制台输出。")


if __name__ == '__main__':
    if not os.path.exists(MODEL_PATH):
        messagebox.showerror("启动错误", f"模型文件 '{MODEL_PATH}' 不存在!\n请确保脚本与 'second_pre' 文件夹在同一目录下。")
    else:
        root = tk.Tk()
        app = PredictionApp(root)
        root.mainloop()
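
The decoding step in `Predictor.predict` can be checked without loading a model or opening the GUI. The following is a minimal, dependency-free sketch of the same idea: softmax over the main-task logits, top-k selection, and the Chinese-name lookup with an English fallback. The helper names `softmax` and `decode_top_k`, the shortened mapping, and the sample logits are hypothetical and exist only for illustration; the script itself does this with `F.softmax` and `torch.topk`.

```python
import math

# A three-entry excerpt of the mapping; 'mystery_dish' is deliberately absent
# so that the English fallback path is exercised.
CLASS_NAME_EN_TO_ZH = {'pizza': '披萨', 'ramen': '拉面', 'sushi': '寿司'}


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def decode_top_k(logits, class_names, top_k=3):
    """Return the top_k (display_name, probability) pairs, best first."""
    probs = softmax(logits)
    top_k = min(top_k, len(class_names))  # assumption: clamp instead of raising
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    results = []
    for i in order:
        name = class_names[i]
        # Same fallback as the GUI: prettified English when no translation exists.
        display = CLASS_NAME_EN_TO_ZH.get(name, name.replace('_', ' ').title())
        results.append((display, probs[i]))
    return results


class_names = ['mystery_dish', 'pizza', 'ramen', 'sushi']
logits = [0.5, 2.0, 1.0, -1.0]  # made-up values for illustration
top3 = decode_top_k(logits, class_names)
for rank, (name, prob) in enumerate(top3, start=1):
    print(f"{rank}. {name}: {prob:.2%}")
```

One design difference worth noting: the sketch clamps `top_k` to the number of classes, whereas `torch.topk` raises an error when `k` exceeds the size of the dimension, so the real `predict` method assumes at least `top_k` dish classes and at least five ingredient classes.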