pytorch에서 Imbalance data set을 만들고 model을 동작했을때 오류.

ex) 10개의 클래스를 가진 데이터를 2개의 클래스만 사용하도록 imbalance한 데이터 셋을 만들었다. 모델의 마지막 fully connected layer의 경우 num_classes를 10에서 사용한 클래스 만큼2로 바꿔주었지만 에러가 발생함.

 

10개의 클래스가 있는 데이터에서 1번과 10번의 클래스를 사용한다고 했을때, target 데이터는 0,9의 값을 갖게됨. 여기서 문제가 발생. pytorch가 0~9개의 y값을 갖는다 생각 하는듯..?

y 값을 0,1로 바꿔주어야함. 즉 클래스 개수만큼 target 데이터는 0 ~ 클래수 갯수 만큼 1씩 증가시켜 줘야함.

ex) 4개의 데이터를 사용하면 target data = 0,1,2,3의 값으로 변경.

 

 

1. target values are not in the expected range of [0, num_classes].

   -> fixed num_classes 

2. change target data -> 0 ~ num_classes 

   if you have 10 classes data (target data is 0~9) but using 3 classes (target data 1,5,9)

   change target data  :   target data 1,5,9 -> target data is 0,1,2 

+ Recent posts