# mnist_ready.py
import os.path import numpy as np
# Base URL of the MNIST archives; an archive name from key_file is appended
# to it to form each download URL.
url_base = 'http://yann.lecun.com/exdb/mnist/'

# Dataset key -> name of the gzip archive (used both remotely and locally).
key_file = {
    'train_img': 'train-images-idx3-ubyte.gz',
    'train_label': 'train-labels-idx1-ubyte.gz',
    'test_img': 't10k-images-idx3-ubyte.gz',
    'test_label': 't10k-labels-idx1-ubyte.gz',
}

# NOTE: '__file__' is a quoted string literal here, not the module attribute,
# so abspath() resolves it against the current working directory and
# dataset_dir ends up being the working directory at import time.
dataset_dir = os.path.dirname(os.path.abspath('__file__'))

# Location of the converted dataset written by download_mnist().
save_file = dataset_dir + "/Mnist_data/mnist_data.npz"

# Dataset dimensions. These were commented out, which left img_size undefined
# even though download_mnist() reads it when reshaping the image arrays.
train_num = 60000
test_num = 10000
img_dim = (1, 28, 28)
img_size = 784
def load_mnist(normalize=True, one_hot_label=False):
    """Load the MNIST dataset from the local .npz cache.

    If save_file does not exist yet, download_mnist() is called first to
    create it.

    Parameters
    ----------
    normalize : bool
        When True, the 'train_img' and 'test_img' arrays are converted to
        float32 and divided by 255.0.
    one_hot_label : bool
        When True, the 'train_label' and 'test_label' arrays are replaced by
        the result of one_hot_label_().

    Returns
    -------
    dict
        Every array stored in save_file, keyed by name. download_mnist()
        stores 'train_img', 'test_img', 'train_label' and 'test_label'.
    """
    if not os.path.exists(save_file):
        download_mnist()

    # Read every array while the archive is open, then close it; wrapping
    # np.load() in dict() directly would leave the file handle open.
    with np.load(save_file) as archive:
        dataset = {name: archive[name] for name in archive.files}

    if one_hot_label:
        for key in ('train_label', 'test_label'):
            dataset[key] = one_hot_label_(dataset[key])

    if normalize:
        for key in ('train_img', 'test_img'):
            scaled = dataset[key].astype(np.float32)
            scaled /= 255.0
            dataset[key] = scaled

    return dataset
def _read_gzip_payload(gz_path, header_bytes):
    """Return the uint8 contents of a gzip file, skipping `header_bytes` bytes."""
    import gzip

    with gzip.open(gz_path, 'rb') as f:
        return np.frombuffer(f.read(), np.uint8, offset=header_bytes)


def download_mnist():
    """Fetch the four MNIST archives and store them as one .npz file.

    For each entry of key_file, the archive is downloaded from url_base into
    "<dataset_dir>/Mnist_data/" unless a file of that name already exists
    there. Image archives are read past a 16-byte header and reshaped to
    (-1, 784); label archives are read past an 8-byte header. The four arrays
    are then saved to "<dataset_dir>/Mnist_data/mnist_data.npz" under the keys
    'train_img', 'test_img', 'train_label' and 'test_label'.

    Returns
    -------
    None
    """
    import urllib.request

    # Flattened size of one 28x28 image. Kept local because the module-level
    # img_size constant is not guaranteed to be defined.
    pixels_per_image = 784

    data_dir = dataset_dir + "/Mnist_data"
    # urlretrieve() and np.savez() both need the target directory to exist.
    os.makedirs(data_dir, exist_ok=True)

    # Dataset key -> number of header bytes to skip before the raw data.
    header_bytes = {
        'train_img': 16,
        'test_img': 16,
        'train_label': 8,
        'test_label': 8,
    }

    D = {}
    for key in ('train_img', 'test_img', 'train_label', 'test_label'):
        fn = key_file[key]
        file_name = data_dir + "/" + fn  # local gzip file
        if not os.path.exists(file_name):
            print("Downloading " + file_name + " ... ")
            urllib.request.urlretrieve(url_base + fn, file_name)
        print("Converting " + file_name + " to NumPy Array ...")
        payload = _read_gzip_payload(file_name, header_bytes[key])
        if key.endswith('_img'):
            payload = payload.reshape(-1, pixels_per_image)
        D[key] = payload

    # np.savez appends the ".npz" suffix to this path.
    np.savez(data_dir + "/mnist_data", **D)
def one_hot_label_(X, num_classes=10):
    """Convert integer class labels into a one-hot encoded matrix.

    Parameters
    ----------
    X : array-like of int
        Class indices; each must be a valid column index for `num_classes`.
    num_classes : int
        Number of columns in the result. Defaults to 10.

    Returns
    -------
    numpy.ndarray
        Float array of shape (number of labels, num_classes) that is 1 at
        [i, X[i]] and 0 everywhere else.
    """
    labels = np.asarray(X)
    T = np.zeros((labels.size, num_classes))
    # Fancy indexing sets one column per row in a single assignment.
    T[np.arange(labels.size), labels] = 1
    return T
|
|
# LAST UPDATE: 2019.07.10 - 15:37
|