Saving and loading a trained model in PyTorch.

 

# Save

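import torch
import resnet  # the author's own ResNet module (not part of PyTorch)

net = resnet.ResNet(110, 100).cuda()  # presumably ResNet-110 with 100 classes (CIFAR-100), on the GPU

train()  # the author's training loop, defined elsewhere

# Save only the learned parameters (the state_dict), not the whole model object
torch.save(net.state_dict(), './resnet110_cifar100.pth')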
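To see what `torch.save(net.state_dict(), ...)` actually writes, here is a minimal sketch that prints a state_dict. `TinyNet` is a hypothetical stand-in for the author's `resnet.ResNet(110, 100)`, chosen so that the example runs on the CPU without the `resnet` module.

```python
import torch
import torch.nn as nn

# Hypothetical small model standing in for resnet.ResNet(110, 100)
class TinyNet(nn.Module):
    def __init__(self, num_classes: int = 100):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(16)
        self.fc = nn.Linear(16, num_classes)

    def forward(self, x):
        x = torch.relu(self.bn(self.conv(x)))
        x = x.mean(dim=(2, 3))  # global average pooling over H and W
        return self.fc(x)

net = TinyNet()
sd = net.state_dict()  # ordered mapping: parameter/buffer name -> tensor
for name, tensor in sd.items():
    print(name, tuple(tensor.shape))
```

Besides the learnable weights, the dict also holds buffers such as BatchNorm's `running_mean`, `running_var` and `num_batches_tracked`. The file therefore stores only named tensors, which is why the same architecture has to be rebuilt before loading.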

 

# Load

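# net must already be built with the same architecture (see above) before loading
net.load_state_dict(torch.load('./resnet110_cifar100.pth'))
net.eval()  # switch BatchNorm/Dropout to inference mode before testing

test()  # the author's evaluation loop, defined elsewhere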
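The full save/load round trip can be checked with a small, self-contained sketch. `make_model` is a hypothetical stand-in for the ResNet, and the file goes to a temporary directory instead of `./resnet110_cifar100.pth`. The `map_location="cpu"` argument of `torch.load` is optional; it lets a checkpoint saved from GPU tensors be loaded on a CPU-only machine.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical small model standing in for resnet.ResNet(110, 100)
def make_model() -> nn.Module:
    return nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 100))

torch.manual_seed(0)
trained = make_model()  # pretend this model has been trained

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pth")
    torch.save(trained.state_dict(), path)  # save parameters only

    restored = make_model()  # rebuild the same architecture first
    state = torch.load(path, map_location="cpu")  # load tensors onto the CPU
    restored.load_state_dict(state)  # copy them into the new model

restored.eval()  # inference mode (matters for BatchNorm/Dropout layers)

# Every saved tensor should now be identical in the restored model
same = all(
    torch.equal(a, b)
    for a, b in zip(trained.state_dict().values(), restored.state_dict().values())
)
print(same)  # True
```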

 

When the model architecture is the same but the parameter names differ.

 

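net = resnet.ResNet(110, 100).cuda()
net.load_state_dict(torch.load('./resnet110_cifar100.pth'))  # pretrained weights

new_model = NewModel(net)  # the author's new model class, defined elsewhere

# Copy the pretrained parameters (from net) into new_model
pretrained_dict = net.state_dict()
model_dict = new_model.state_dict()
# Keep only entries whose names also exist in new_model; differently named ones are skipped, not renamed
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

# update & load
model_dict.update(pretrained_dict)
new_model.load_state_dict(model_dict)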
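When the names genuinely differ, the keys have to be renamed before the filter, update and load steps. A common real case is a checkpoint saved from a model wrapped in `nn.DataParallel`, whose keys all start with `module.`. The sketch below assumes a hypothetical `Wrapper` that adds a `backbone.` prefix and a target model with an extra output head; it strips the prefix, keeps only keys that exist in the target with the same shape, and loads the result.

```python
import torch
import torch.nn as nn

def make_backbone() -> nn.Sequential:
    return nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 100))

class Wrapper(nn.Module):
    """Hypothetical wrapper: same layers, but every key gains a 'backbone.' prefix."""
    def __init__(self, backbone: nn.Module):
        super().__init__()
        self.backbone = backbone

torch.manual_seed(0)
pretrained = Wrapper(make_backbone())  # source keys: 'backbone.0.weight', ...
# Target: same backbone layers (keys '0.*', '2.*') plus a new head (keys '4.*')
target = nn.Sequential(*make_backbone(), nn.ReLU(), nn.Linear(100, 10))
head_before = target[4].weight.detach().clone()

# 1) Rename: strip the prefix so the names line up with the target
prefix = "backbone."
pretrained_dict = {
    (k[len(prefix):] if k.startswith(prefix) else k): v
    for k, v in pretrained.state_dict().items()
}

# 2) Filter: keep names that exist in the target with the same tensor shape
model_dict = target.state_dict()
pretrained_dict = {
    k: v for k, v in pretrained_dict.items()
    if k in model_dict and v.shape == model_dict[k].shape
}

# 3) Update & load: unmatched target entries (the new head) keep their current values
model_dict.update(pretrained_dict)
target.load_state_dict(model_dict)

print(sorted(pretrained_dict))  # ['0.bias', '0.weight', '2.bias', '2.weight']
```

PyTorch's `load_state_dict(..., strict=False)` is another option: it ignores missing and unexpected keys, but it still raises an error on shape mismatches and does not rename anything, so the prefix would still have to be stripped by hand.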
