Saving and loading a trained model in PyTorch
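import torch
import resnet  # the author's own module that defines ResNet

# Save
net = resnet.ResNet(110, 100).cuda()
train()
torch.save(net.state_dict(), './resnet110_cifar100.pth')
# Load (build the same architecture first, then copy the saved weights into it)
net = resnet.ResNet(110, 100).cuda()
net.load_state_dict(torch.load('./resnet110_cifar100.pth'))
test()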
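For reference, here is a self-contained sketch of the same save/load round trip that runs without a GPU. `TinyNet`, `save_weights` and `load_weights` are hypothetical stand-ins, since the author's `resnet.ResNet`, `train()` and `test()` are not shown. The `map_location="cpu"` argument and the `model.eval()` call are additions that do not appear in the original.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for resnet.ResNet(110, 100)."""

    def __init__(self, num_classes: int = 100):
        super().__init__()
        self.features = nn.Linear(32, 64)
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        return self.fc(torch.relu(self.features(x)))


def save_weights(model: nn.Module, path: str) -> None:
    # Only the parameters and buffers are saved, not the class definition.
    torch.save(model.state_dict(), path)


def load_weights(path: str, num_classes: int = 100) -> nn.Module:
    # Rebuild the same architecture, then copy the saved tensors into it.
    model = TinyNet(num_classes)
    # map_location="cpu" lets a checkpoint written on a GPU load on a CPU-only machine.
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state)
    model.eval()  # put dropout / batch-norm layers into inference mode before testing
    return model


if __name__ == "__main__":
    torch.manual_seed(0)
    net = TinyNet()
    path = os.path.join(tempfile.mkdtemp(), "tinynet.pth")
    save_weights(net, path)
    restored = load_weights(path)
    x = torch.randn(4, 32)
    print(torch.allclose(net(x), restored(x)))  # True
```

Saving the `state_dict` rather than the whole module keeps the file independent of the exact class code, which is why the model has to be constructed again before `load_state_dict`.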
When the model structure is the same but the parameter names differ
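# NewModel is the author's own model class (not shown)
net = resnet.ResNet(110, 100).cuda()
net.load_state_dict(torch.load('./resnet110_cifar100.pth'))
new_model = NewModel(net)
# The trained (pretrained) weights come from net and are copied into new_model
pretrained_dict = net.state_dict()
model_dict = new_model.state_dict()
# Keep only the entries whose names also exist in new_model; the rest keep their initial values
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# update & load
model_dict.update(pretrained_dict)
new_model.load_state_dict(model_dict)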
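The dictionary filter above copies only the entries whose names already match. When the names really do differ, one option is to rename the keys before filtering. The sketch below assumes the two models differ only by a layer-name prefix; `OldNet`, `NewNet`, `rename_keys` and `load_matching` are hypothetical names, and the shape check inside `load_matching` is an extra safeguard that is not in the original code.

```python
import torch
import torch.nn as nn


class OldNet(nn.Module):
    """Hypothetical trained model: layers are called `body` and `fc`."""

    def __init__(self):
        super().__init__()
        self.body = nn.Linear(8, 16)
        self.fc = nn.Linear(16, 10)


class NewNet(nn.Module):
    """Hypothetical new model: same shapes, different layer names."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 16)
        self.classifier = nn.Linear(16, 10)


def rename_keys(state_dict, prefix_map):
    """Return a copy of state_dict with key prefixes rewritten (old prefix -> new prefix)."""
    renamed = {}
    for key, value in state_dict.items():
        for old, new in prefix_map.items():
            if key.startswith(old):
                key = new + key[len(old):]
                break
        renamed[key] = value
    return renamed


def load_matching(model, pretrained_dict):
    """Copy every tensor whose name and shape match; leave the others untouched."""
    model_dict = model.state_dict()
    matched = {
        k: v for k, v in pretrained_dict.items()
        if k in model_dict and v.shape == model_dict[k].shape
    }
    model_dict.update(matched)
    model.load_state_dict(model_dict)
    return sorted(matched)


old = OldNet()
new = NewNet()
pretrained = rename_keys(old.state_dict(), {"body.": "backbone.", "fc.": "classifier."})
copied = load_matching(new, pretrained)
print(copied)  # ['backbone.bias', 'backbone.weight', 'classifier.bias', 'classifier.weight']
```

The same `rename_keys` helper can, for example, strip the `module.` prefix that `nn.DataParallel` adds to every key: `rename_keys(state, {"module.": ""})`.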