什么蘑菇?
最后更新 2021/08/27 17:42
阅读 8576
CNN PyTorch 图像分类
黑羽
4
获得赞0
发布的文章2
答辩的项目Batch大小为1,循环次数为1次,通过在线上环境完成训练,模型最优精度评分为97.68。
最后更新 2021/08/27 17:42
阅读 8576
CNN PyTorch 图像分类
赛题背景:
本次赛题背景是日常生活中蘑菇类别较多,而其中包含了很多有毒蘑菇,人很难辨别,容易引起误食,该比赛想通过深度学习的方法帮助人们识别出毒蘑菇。本次竞赛使用的数据集由北欧真菌学家协会提供的9 种常见北欧蘑菇属的图像组成。
赛题分析:
这个题目就是一个图像分类问题,给定了6045张图片作为训练集,675张图片作为测试集,对于线下的10%数据。
实验过程:
1、5折交叉划分训练集和验证集,遍历每一个类别,训练与验证数据比例大致相同
2、试验了各种分类模型
3、预测的时候使用水平翻转求平均的简单策略
4、测试各种超参数 batchsize,学习率,增强策略,优化器,损失函数等
提分的一些操作:
1、标签平滑
2、模型选择 Swim-Transformer
3、数据集划分 9:1验证集有过拟合问题,8:2稍微好一些
4、模型融合
5、随机裁剪、随机擦除
6、tta
最终模型和参数:
(1)数据集划分: 训练集5折交叉 训练集:验证集=8:2
(2)模型选择: Swim-Transformer 预训练模型: swin_large_patch4_window12_384_22k.pth
(3)inputsize: 383*384
(4)batchsize: 4
(5)max_epochs: 30
(6)学习率: 阶梯下降学习率
def lr_scheduler(optimizer, epoch):
if epoch < 10:
lr = 1e-3
elif epoch < 20:
lr = 1e-4
else: lr = 1e-5
for param_group in optimizer.param_groups:
param_group['lr'] = lr
return optimizer
( 7)优化器:
self.optimizer = optim.SGD((self.model_ft.parameters()), lr=1e-3, momentum=momentum, weight_decay=0.0005) (8)损失函数:
LabelSmoothingCrossEntropy(smoothing=0.1).cuda()
(9)增强策略:
训练集: 随机裁剪、水平翻转、随机擦除
验证集:简单缩放
data_transforms = {
'train': transforms.Compose([
transforms.RandomResizedCrop((input_size, input_size)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]),
'val': transforms.Compose([
transforms.Resize((input_size,input_size)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]),
(10)测试增强:
正常图片+ 水平翻转 -> 求平均
(11)融合策略:
保存每折训练过程中最大Acc和最小Loss的模型,数据集5折交叉,共5*2=10个模型进行融合
教训总结:
1、过早的模型融合,忽略了单个模型的性能比较
2、验证集过拟合问题
3、代码没有线下测试,导致线上出问题,浪费了很多实验时间
总结和展望:
1、对图片没有进行预处理等操作,数据分析还不足
2、多看看最新的研究,最后阶段才尝试使用最新的分类模型,对比之前的模型提分很明显
3、训练速度较慢,前排大神两次训练间隔才10几分钟就可以得到比较好的分数,我每次训练都要好几个小时,很想学习下,以后可以改进下训练策略
4、参加几次比赛,我也学到了很多东西,希望以后有更多的人来参与这个平台进行比赛,通过比赛和分享让大家共同进步。谢谢大家!
CNN PyTorch 图像分类
请先绑定您的微信账号 点击立即绑定
敬请谅解,如有疑问请联系FlyAI客服