Study/Deep Learning

MNIST Data 랜덤 추출.

MJ_DL 2018. 11. 15. 17:53

MNIST DATA

  • Imbalance data set을 만들기 위한 코드.
  • 클래스 라벨과 갯수를 입력으로 받는 함수.
  • 랜덤으로 데이터가 추출되며, 트레이닝 데이터셋을 만들시 리스트에 계속 추가 가능.


from tensorflow.examples.tutorials.mnist import input_data

import random

import numpy as np

import pandas as pd


# 데이터 추출

def Creat_Imbalance_Mnist_data(class_label, data_count,mnist_x,mnist_y):

    # 데이터 생성 개수 체크.

    count = 0

    # 생성 데이터 저장할 리스트.

    li_x_data = []

    li_y_data = []


    while count < data_count:

        random_num = random.randint(1,49999)

        

        if(mnist_y[random_num] == class_label):

            li_x_data.append(mnist_x[random_num])

            li_y_data.append(mnist_y[random_num])

            count+=1

            

    return li_x_data, li_y_data


# one-hot-encoding

def one_hot_encoding(label):

    cls = set(label)

    class_dict = {c: np.identity(len(cls))[i, :] for i, c in enumerate(cls)}

    one_hot = np.array(list(map(class_dict.get, label)))

    

    return one_hot


# Mnist 데이터 불러오기.

mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)


origin_data_x,origin_data_y =  mnist.train.next_batch(50000)


# 라벨 순서 저장.

label = [0,1,2,3,4,5,6,7,8,9]

# 데이터 갯수.

count = [10,20,30,40,50,60,70,80,90,300]


# 함수 데이터를 받아올 리스트

train_data_x = []

train_data_y = []


# 데이터 추출

for i in range(10):

    data_x, data_y = Creat_Imbalance_Mnist_data(label[i],count[i],origin_data_x,origin_data_y)

    

    train_data_x.append(data_x)

    train_data_y.append(data_y)


# x 차원 낮추기.

train_x = [a for i in train_data_x for a in i]

# one - hot _encoder

one_hot = [a for i in train_data_y for a in i]

train_y_one_hot = one_hot_encoding(one_hot)


print(len(train_x))

print(len(train_y_one_hot))