resnet RuntimeError: Given groups=1, weight of size [16, 3, 3, 3], expected input[128, 32, 32, 3] to have 3 channels, but got 32 channels instead 오류
Cifar 데이터를 Pytorch에서 제공하는 dataloader가 아닌 Custom data set으로 만들어 주었을때 위와 같은 오류가 발생.
보통의 경우 dataloader에서 나오는 데이터형식과 모델에서 받는 데이터 형식이 다르기 때문이다.
x 데이터의 형식을 torch.permute()로 변경해서 해결.
monetworkel의 forward의 단에서 수정을 해줌.
loader를 통해서 network에 input 변수(x data)가 들어감.
for i, (input, target) in enumerate(loader):
output = network(input)
model 정의 부분에 forward 부분에서 s.Szie를 해준 결과 [128,32,32,3]의 결과가 나옴.
pytorch에서 정의해준 데이터셋을 사용할 경우 [128,3,32,32]이 나오게 됨.
torch.permute()를 이용해서 데이터의 위치를 바꿔줘야함.
loader를 통해 나온 값은 [128,3,32,32] -> [배치사이즈,이미지채널,이미지가로,이미지세로]의 형식을 가져야한다.
def forward(self, x):
x = x.permute(0,3,1,2)
out = self.conv1(x)