如何用预训练的.pth模型快速微调你的自定义数据集(ResNet50实战)

张开发
2026/4/19 17:38:00 15 分钟阅读

分享文章

如何用预训练的.pth模型快速微调你的自定义数据集(ResNet50实战)
ResNet50实战从预训练模型到自定义数据集的工业级微调指南在计算机视觉领域迁移学习已经成为解决实际业务问题的标准方法。当你的数据集只有几千张图片时从头训练一个深度神经网络几乎不可能获得好效果。这时预训练模型就像一位经验丰富的老师已经掌握了识别图像的基础能力只需要针对你的特定任务稍作调整。本文将带你深入ResNet50的微调过程不仅涵盖基础操作更会分享工业实践中验证过的技巧。1. 环境准备与模型加载工欲善其事必先利其器。在开始之前确保你的环境满足以下要求Python 3.7PyTorch 1.8torchvision 0.9CUDA 11.x如果使用GPUpip install torch torchvision torchaudio加载预训练ResNet50模型只需几行代码但有几个关键细节需要注意import torch import torchvision.models as models # 加载预训练模型不自动下载 model models.resnet50(pretrainedFalse) pretrained_dict torch.load(resnet50-0676ba61.pth) model.load_state_dict(pretrained_dict)提示直接从torchvision加载模型会下载权重到缓存目录。在生产环境中建议先下载.pth文件到本地避免因网络问题导致部署失败。ResNet50的原始分类层有1000个输出节点我们需要替换为自定义类别数num_classes 10 # 假设你的数据集有10类 model.fc torch.nn.Linear(model.fc.in_features, num_classes)2. 数据准备与增强策略数据准备是模型微调成功的关键。不同于从头训练微调时数据增强应该更保守from torchvision import transforms train_transform transforms.Compose([ transforms.Resize(256), transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ColorJitter(brightness0.2, contrast0.2), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) val_transform transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ])对于自定义数据集推荐使用ImageFolder结构dataset/ train/ class1/ img1.jpg img2.jpg class2/ img1.jpg val/ class1/ img3.jpg class2/ img2.jpg加载数据时注意batch size的设置from torchvision.datasets import ImageFolder train_dataset ImageFolder(dataset/train, train_transform) val_dataset ImageFolder(dataset/val, val_transform) train_loader torch.utils.data.DataLoader( train_dataset, batch_size32, shuffleTrue, num_workers4) val_loader torch.utils.data.DataLoader( val_dataset, batch_size32, shuffleFalse, num_workers4)3. 模型微调的高级技巧3.1 分层学习率设置不同层应该使用不同的学习率。浅层提取基础特征应该用较小的学习率深层和分类层可以大一些optimizer torch.optim.SGD([ {params: model.conv1.parameters(), lr: 0.0001}, {params: model.layer1.parameters(), lr: 0.0001}, {params: model.layer2.parameters(), lr: 0.0005}, {params: model.layer3.parameters(), lr: 0.001}, {params: model.layer4.parameters(), lr: 0.005}, {params: model.fc.parameters(), lr: 0.01} ], momentum0.9)3.2 学习率调度策略结合余弦退火和热重启可以获得更好的收敛效果from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts scheduler CosineAnnealingWarmRestarts(optimizer, T_010, # 第一次周期长度 T_mult2, # 每次周期长度倍增 eta_min1e-6) # 最小学习率3.3 早停与模型保存实现一个简单的早停机制best_val_loss float(inf) patience 5 counter 0 for epoch in range(100): # 训练和验证代码... if val_loss best_val_loss: best_val_loss val_loss torch.save(model.state_dict(), best_model.pth) counter 0 else: counter 1 if counter patience: print(早停触发) break4. 模型评估与部署训练完成后在测试集上评估模型性能model.eval() correct 0 total 0 with torch.no_grad(): for images, labels in test_loader: images images.to(device) labels labels.to(device) outputs model(images) _, predicted torch.max(outputs.data, 1) total labels.size(0) correct (predicted labels).sum().item() print(f测试准确率: {100 * correct / total:.2f}%)对于生产部署建议将模型转换为TorchScript格式example_input torch.rand(1, 3, 224, 224).to(device) traced_script_module torch.jit.trace(model, example_input) traced_script_module.save(resnet50_scripted.pt)最后分享一个实际项目中的经验当数据集较小时冻结除最后一层外的所有参数先训练50个epoch然后解冻所有层用较小的学习率微调20个epoch这样通常能获得比直接微调更好的效果。

更多文章