¸ÞÀϺ¸³»±â

이름검색

::: Research Board :::


9 11 통계카운터 보기   회원 가입 회원 로그인 관리자 접속 --+
Name   신병춘
Subject   손글씨 준비
# mnist_ready.py

import os.path
import numpy as np

# Base URL the MNIST archives are fetched from. It must not start with
# whitespace: urllib rejects ' http://...' as an unknown URL type.
# NOTE(review): this host may refuse automated downloads; a mirror may be
# needed - confirm before relying on it.
url_base = 'http://yann.lecun.com/exdb/mnist/'

# Logical dataset name -> gzip archive name under url_base.
key_file = {
    'train_img': 'train-images-idx3-ubyte.gz',
    'train_label': 'train-labels-idx1-ubyte.gz',
    'test_img': 't10k-images-idx3-ubyte.gz',
    'test_label': 't10k-labels-idx1-ubyte.gz',
}

# Directory that holds the Mnist_data/ cache folder: the directory of this
# module. __file__ must be the *variable*, not the string '__file__' (which
# would resolve relative to the current working directory). Interactive
# sessions without a __file__ fall back to the working directory.
try:
    dataset_dir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    dataset_dir = os.getcwd()

# Single .npz cache holding all four arrays.
save_file = os.path.join(dataset_dir, "Mnist_data", "mnist_data.npz")

# MNIST: 60000 training / 10000 test images, each 1 x 28 x 28 pixels,
# stored flattened as img_size values per image.
img_size = 784

def load_mnist(normalize=True, one_hot_label=False):
    """Load the MNIST dataset, downloading and caching it on first use.

    The arrays are read from the ``.npz`` cache at ``save_file``. If that
    file does not exist yet, ``download_mnist()`` is called first to build it.

    Args:
        normalize: if True, convert ``train_img`` and ``test_img`` to float32
            and divide by 255.0.
        one_hot_label: if True, replace ``train_label`` and ``test_label``
            with the one-hot matrices returned by ``one_hot_label_``.

    Returns:
        dict with keys ``'train_img'``, ``'test_img'``, ``'train_label'`` and
        ``'test_label'`` (one entry per array stored in the ``.npz`` cache).
    """
    if not os.path.exists(save_file):
        download_mnist()

    # Use the NpzFile as a context manager so its underlying file handle is
    # closed once every array has been copied into a plain dict.
    with np.load(save_file) as npz:
        dataset = dict(npz)

    if one_hot_label:
        for key in ('train_label', 'test_label'):
            dataset[key] = one_hot_label_(dataset[key])

    if normalize:
        for key in ('train_img', 'test_img'):
            # astype() returns a fresh float32 copy, so the in-place
            # division below never touches the integer source array.
            images = dataset[key].astype(np.float32)
            images /= 255.0
            dataset[key] = images

    return dataset

def download_mnist():
    """Download the MNIST archives and cache them as a single ``.npz`` file.

    For each entry of ``key_file``, the gzip archive is fetched from
    ``url_base`` into ``<dataset_dir>/Mnist_data/``. The fetch is skipped when
    the file already exists. The archive is then decompressed and parsed into
    a uint8 NumPy array. Image arrays are reshaped to ``(-1, img_size)`` and
    label arrays stay 1-D. All four arrays are saved together to
    ``save_file``, the path that ``load_mnist`` checks for.

    Errors raised by ``urllib.request.urlretrieve`` or by file I/O are not
    caught and propagate to the caller.
    """
    import gzip
    import urllib.request

    data_dir = os.path.join(dataset_dir, "Mnist_data")
    # urlretrieve and np.savez do not create missing parent directories.
    os.makedirs(data_dir, exist_ok=True)

    def _fetch_and_parse(key, header_bytes):
        """Return the payload of archive ``key_file[key]`` as a uint8 array."""
        fn = key_file[key]
        file_name = os.path.join(data_dir, fn)  # gzip file
        if not os.path.exists(file_name):
            print("Downloading " + file_name + " ... ")
            # Download under a temporary name and rename only on success, so
            # an interrupted transfer never leaves a truncated archive that
            # the exists() check above would later accept as complete.
            partial_name = file_name + ".part"
            urllib.request.urlretrieve(url_base + fn, partial_name)
            os.replace(partial_name, file_name)
        print("Converting " + file_name + " to NumPy Array ...")
        with gzip.open(file_name, 'rb') as f:
            # Skip the fixed-size file header; the rest is one byte per value.
            return np.frombuffer(f.read(), np.uint8, offset=header_bytes)

    D = {}
    # Image archives: 16-byte header, then img_size pixel bytes per image.
    for key in ('train_img', 'test_img'):
        D[key] = _fetch_and_parse(key, 16).reshape(-1, img_size)
    # Label archives: 8-byte header, then one byte per label.
    for key in ('train_label', 'test_label'):
        D[key] = _fetch_and_parse(key, 8)

    # save_file already ends in ".npz", so np.savez writes exactly that path.
    np.savez(save_file, train_img=D['train_img'], test_img=D['test_img'],
             train_label=D['train_label'], test_label=D['test_label'])

def one_hot_label_(X, num_classes=10):
    """Convert integer class labels into a one-hot encoded matrix.

    Args:
        X: 1-D array-like of integer class labels (e.g. the uint8 MNIST
            labels), each expected to lie in ``range(num_classes)``.
        num_classes: width of each one-hot row. Defaults to 10, one column
            per MNIST digit.

    Returns:
        float64 array of shape ``(len(X), num_classes)``. Row ``i`` is all
        zeros except for a 1.0 in column ``X[i]``.

    Raises:
        IndexError: if a label is >= ``num_classes``.
    """
    labels = np.asarray(X)
    T = np.zeros((labels.size, num_classes))
    # Fancy indexing sets exactly one cell per row: (row i, column labels[i]).
    T[np.arange(labels.size), labels] = 1
    return T

                            

게시물을 이메일로 보내기 프린트출력을 위한 화면보기
DATE: 2019.07.09 - 15:45
LAST UPDATE: 2019.07.10 - 15:37


 이전글 손글씨 학습
 다음글 MNIST tf 실행 파일
글남기기 삭제하기 수정하기 답변달기 전체 목록 보기