제 목: 손글씨 준비 |
이 름: 신병춘 |
작성일자: 2019.07.09 - 15:45 |
# mnist_ready.py
# Download the MNIST gzip archives, cache them as a single .npz file, and
# load that cache as a dictionary of NumPy arrays.
import os.path

import numpy as np

url_base = 'http://yann.lecun.com/exdb/mnist/'
key_file = {
    'train_img': 'train-images-idx3-ubyte.gz',
    'train_label': 'train-labels-idx1-ubyte.gz',
    'test_img': 't10k-images-idx3-ubyte.gz',
    'test_label': 't10k-labels-idx1-ubyte.gz',
}

# NOTE(review): '__file__' is a quoted string here, not the module attribute,
# so this resolves to the current working directory rather than the folder
# containing this script. Kept as-is because changing it would move the
# location of already-downloaded caches -- confirm which one is intended.
dataset_dir = os.path.dirname(os.path.abspath('__file__'))
save_file = dataset_dir + "/Mnist_data/mnist_data.npz"

# train_num = 60000; test_num = 10000; img_dim = (1, 28, 28)
# Number of pixels in one flattened image. download_mnist() needs this name;
# it previously existed only inside the comment above, causing a NameError.
img_size = 784


def load_mnist(normalize=True, one_hot_label=False):
    """Load the MNIST dataset, building the local cache first if it is missing.

    Args:
        normalize: If True, images are converted to float32 and divided by
            255.0. If False, they are returned as stored in the cache.
        one_hot_label: If True, both label arrays are replaced by their
            one-hot encoding (see one_hot_label_).

    Returns:
        Dictionary with keys 'train_img', 'test_img', 'train_label' and
        'test_label', each mapping to a NumPy array.
    """
    if not os.path.exists(save_file):
        download_mnist()

    # Read every array eagerly, then close the archive so no file handle
    # stays open behind the returned dictionary.
    with np.load(save_file) as npz:
        dataset = {key: npz[key] for key in npz.files}

    # Making one-hot-label if needed
    if one_hot_label:
        for key in ('train_label', 'test_label'):
            dataset[key] = one_hot_label_(dataset[key])

    # Normalizing
    if normalize:
        for key in ('train_img', 'test_img'):
            dataset[key] = dataset[key].astype(np.float32) / 255.0

    return dataset


def _read_gzip_array(file_name, offset):
    """Return the bytes of a gzip file after *offset* as a flat uint8 array."""
    import gzip

    with gzip.open(file_name, 'rb') as f:
        return np.frombuffer(f.read(), np.uint8, offset=offset)


def download_mnist():
    """Fetch any missing MNIST archives and write the combined .npz cache.

    Archives already present under <dataset_dir>/Mnist_data are reused
    without contacting the server. Image archives are reshaped to
    (-1, img_size); label archives stay one-dimensional. The result is
    written to save_file.

    Raises:
        Whatever urllib.request.urlretrieve raises when a download fails;
        the partially written archive is removed before re-raising.
    """
    import urllib.request

    data_dir = dataset_dir + "/Mnist_data"
    # urlretrieve does not create missing parent directories.
    os.makedirs(data_dir, exist_ok=True)

    # (key, number of header bytes to skip), in the original processing order.
    sources = (
        ('train_img', 16),
        ('test_img', 16),
        ('train_label', 8),
        ('test_label', 8),
    )

    D = {}
    for key, header_bytes in sources:
        fn = key_file[key]
        file_name = data_dir + "/" + fn  # gzip file
        if not os.path.exists(file_name):
            print("Downloading " + file_name + " ... ")
            try:
                urllib.request.urlretrieve(url_base + fn, file_name)
            except BaseException:
                # A truncated archive would be mistaken for a finished
                # download on the next run, so remove it before re-raising.
                if os.path.exists(file_name):
                    os.remove(file_name)
                raise
        print("Converting " + file_name + " to NumPy Array ...")
        array = _read_gzip_array(file_name, header_bytes)
        if key.endswith('_img'):
            array = array.reshape(-1, img_size)
        D[key] = array

    # Write to the exact path load_mnist() reads from.
    np.savez(save_file,
             train_img=D['train_img'], test_img=D['test_img'],
             train_label=D['train_label'], test_label=D['test_label'])


def one_hot_label_(X, num_classes=10):
    """Return the one-hot encoding of an integer label array.

    Args:
        X: One-dimensional integer array of labels in [0, num_classes).
        num_classes: Number of columns in the result (default 10).

    Returns:
        Array of shape (X.size, num_classes) filled with zeros, with a 1
        at column X[i] of row i.
    """
    T = np.zeros((X.size, num_classes))
    T[np.arange(X.size), X] = 1
    return T