MNIST Data 랜덤 추출.
MNIST DATA
- Imbalance data set을 만들기 위한 코드.
- 클래스 라벨과 갯수를 입력으로 받는 함수.
- 랜덤으로 데이터가 추출되며, 트레이닝 데이터셋을 만들시 리스트에 계속 추가 가능.
from tensorflow.examples.tutorials.mnist import input_data
import random
import numpy as np
import pandas as pd
# 데이터 추출
def Creat_Imbalance_Mnist_data(class_label, data_count,mnist_x,mnist_y):
# 데이터 생성 개수 체크.
count = 0
# 생성 데이터 저장할 리스트.
li_x_data = []
li_y_data = []
while count < data_count:
random_num = random.randint(1,49999)
if(mnist_y[random_num] == class_label):
li_x_data.append(mnist_x[random_num])
li_y_data.append(mnist_y[random_num])
count+=1
return li_x_data, li_y_data
# one-hot-encoding
def one_hot_encoding(label):
cls = set(label)
class_dict = {c: np.identity(len(cls))[i, :] for i, c in enumerate(cls)}
one_hot = np.array(list(map(class_dict.get, label)))
return one_hot
# Mnist 데이터 불러오기.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)
origin_data_x,origin_data_y = mnist.train.next_batch(50000)
# 라벨 순서 저장.
label = [0,1,2,3,4,5,6,7,8,9]
# 데이터 갯수.
count = [10,20,30,40,50,60,70,80,90,300]
# 함수 데이터를 받아올 리스트
train_data_x = []
train_data_y = []
# 데이터 추출
for i in range(10):
data_x, data_y = Creat_Imbalance_Mnist_data(label[i],count[i],origin_data_x,origin_data_y)
train_data_x.append(data_x)
train_data_y.append(data_y)
# x 차원 낮추기.
train_x = [a for i in train_data_x for a in i]
# one - hot _encoder
one_hot = [a for i in train_data_y for a in i]
train_y_one_hot = one_hot_encoding(one_hot)
print(len(train_x))
print(len(train_y_one_hot))