问题提出:用pytorch训练VGG16分类,loss从0.69下降到0.24就开始小幅度震荡,不管如何调整batch_size和learning_rate都无法解决。

原因:没有加载预训练模型

那么问题来了,官方给出的是1000类的ImageNet预训练模型    
https://download.pytorch.org/models/vgg16-397923af.pth
<https://download.pytorch.org/models/vgg16-397923af.pth>
,而我要做的是20类数据集的分类,如何使用这一预训练的权重。
def vgg16(pretrained=False, **kwargs): """VGG 16-layer model (configuration
"D")""" model = VGG(make_layers(cfg['D']), **kwargs) if pretrained:
model.load_state_dict(torch.load('./vgg16-397923af.pth')) model.classifier =
nn.Sequential( nn.Linear(512 * 7 * 7, 4096), nn.ReLU(inplace=True),
nn.Dropout(), nn.Linear(4096, 4096), nn.ReLU(inplace=True), nn.Dropout(),
nn.Linear(4096, you_class_num), ) return model
其中VGG按照官方给出的构造方法构造class VGG即可。

先构造1000类的VGG模型,用于加载pth预训练模型,然后重新构造分类层,将最后一层全连接层设置为需要的类别数量即可。

友情链接
KaDraw流程图
API参考文档
OK工具箱
云服务器优惠
阿里云优惠券
腾讯云优惠券
华为云优惠券
站点信息
问题反馈
邮箱:ixiaoyang8@qq.com
QQ群:637538335
关注微信