问题提出:用pytorch训练VGG16分类,loss从0.69下降到0.24就开始小幅度震荡,不管如何调整batch_size和learning_rate都无法解决。
原因:没有加载预训练模型
那么问题来了,官方给出的是1000类的ImageNet预训练模型
https://download.pytorch.org/models/vgg16-397923af.pth
<https://download.pytorch.org/models/vgg16-397923af.pth>
,而我要做的是20类数据集的分类,如何使用这一预训练的权重。
def vgg16(pretrained=False, **kwargs): """VGG 16-layer model (configuration
"D")""" model = VGG(make_layers(cfg['D']), **kwargs) if pretrained:
model.load_state_dict(torch.load('./vgg16-397923af.pth')) model.classifier =
nn.Sequential( nn.Linear(512 * 7 * 7, 4096), nn.ReLU(inplace=True),
nn.Dropout(), nn.Linear(4096, 4096), nn.ReLU(inplace=True), nn.Dropout(),
nn.Linear(4096, you_class_num), ) return model
其中VGG按照官方给出的构造方法构造class VGG即可。
先构造1000类的VGG模型,用于加载pth预训练模型,然后重新构造分类层,将最后一层全连接层设置为需要的类别数量即可。
热门工具 换一换