::: Research Board :::


Name: 신병춘
Subject: Handwritten digit training
# mnist_train.py
# Train a neural network that classifies handwritten digits

from ShinNet_relu import Inet1, Inet2  # ReLU version: Inet2(lr=0.5) => 97% accuracy
#from ShinNet_sigmoid import Inet1, Inet2  # sigmoid version: Inet2(lr=0.1) => 92% accuracy
from mnist_ready import load_mnist  # load dataset from mnist_ready.py

mnist = load_mnist(normalize=True, one_hot_label=False)
# Split into training and test sets: X_train, t_train, X_test, t_test
X_train, t_train = mnist['train_img'], mnist['train_label']
X_test, t_test = mnist['test_img'], mnist['test_label']

img_size = X_train.shape[1]
N_clss = 10

batch_size = 100
N_epochs = 20

''' One layer NN '''
layer1 = 0  # set to 1 to train the one-layer network
if layer1 == 1:
    Net = Inet1(i_size=img_size, o_size=N_clss)
    learning_rate = .02 # Inet1
    print("---- Classification for Mnist using One-layer ANN training. ---- ")
    for epoch in range(1, N_epochs+1):
        Net.train(X_train, t_train, learning_rate, batch_size)
        acc1 = Net.accuracy(X_train, t_train)
        acc2 = Net.accuracy(X_test, t_test)
        print('After epoch {}: train accuracy {:.4f}, test accuracy {:.4f}'
              .format(epoch, acc1, acc2))
        # stop once the mean of train and test accuracy exceeds 0.97
        if (acc1 + acc2) / 2 > 0.97:
            break

print('')
layer2 = 1  # set to 0 to skip the two-layer network
if layer2 == 1:
    ''' Two layer NN '''
    Net = Inet2(i_size=img_size, h_size=100, o_size=N_clss)
    learning_rate = .5     # relu
#    learning_rate = 0.1     # sigmoid
    print("---- Classification for Mnist using Two-layer ANN training. ---- ")
    for epoch in range(1, N_epochs+1):
        Net.train(X_train, t_train, learning_rate, batch_size)
        acc1 = Net.accuracy(X_train, t_train)
        acc2 = Net.accuracy(X_test, t_test)
        print('After epoch {}: train accuracy {:.4f}, test accuracy {:.4f}'
              .format(epoch, acc1, acc2))
        # stop once the mean of train and test accuracy exceeds 0.97
        if (acc1 + acc2) / 2 > 0.97:
            break
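mnist_ready.py is not included in this post, so the exact preprocessing behind load_mnist(normalize=True, one_hot_label=False) is not visible here. The following is a minimal sketch assuming the common convention for these two flags: normalize=True turns the 0..255 pixel bytes into floats in [0, 1], and one_hot_label=False keeps each label as a single integer 0..9 instead of a length-10 vector. The helper preprocess and the random fake images are hypothetical and only illustrate that assumption.

```python
import numpy as np

def preprocess(images_uint8, labels, normalize=True, one_hot_label=False, n_classes=10):
    """Hypothetical helper mimicking the assumed effect of load_mnist's two flags."""
    X = images_uint8.astype(np.float32)
    if normalize:
        X /= 255.0  # map pixel intensities 0..255 onto 0.0..1.0
    if one_hot_label:
        # each label becomes a length-10 vector with a single 1.0
        T = np.zeros((labels.size, n_classes), dtype=np.float32)
        T[np.arange(labels.size), labels] = 1.0
        return X, T
    return X, labels  # integer class indices 0..9, as mnist_train.py requests

# three fake flattened 28x28 images standing in for MNIST data
rng = np.random.default_rng(0)
imgs = rng.integers(0, 256, size=(3, 784)).astype(np.uint8)
labels = np.array([5, 0, 4])

X, t = preprocess(imgs, labels)
print(X.shape, X.dtype, t)  # (3, 784) float32 [5 0 4]
```

If mnist_ready provides the standard 60,000-image MNIST training split, batch_size = 100 means one epoch corresponds to 600 minibatch updates.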

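Inet1 is imported from ShinNet_relu.py, which is not shown in this post. As a sketch under stated assumptions, here is what a one-layer classifier with the same interface could look like: a constructor taking i_size and o_size, train(X, t, learning_rate, batch_size) running one epoch (assumed from how the script's loop counts epochs), and accuracy(X, t). It is plain softmax regression trained by minibatch SGD in NumPy; OneLayerSketch is a hypothetical name, not the author's class, and toy 2-D data replaces MNIST so the example runs instantly.

```python
import numpy as np

def softmax(z):
    # subtract the row-wise max before exponentiating for numerical stability
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

class OneLayerSketch:
    """Hypothetical stand-in for Inet1: softmax regression trained by minibatch SGD."""

    def __init__(self, i_size, o_size, seed=0):
        self.rng = np.random.default_rng(seed)
        self.W = 0.01 * self.rng.standard_normal((i_size, o_size))
        self.b = np.zeros(o_size)

    def scores(self, X):
        return X @ self.W + self.b

    def train(self, X, t, learning_rate, batch_size):
        """One epoch of SGD on the mean cross-entropy loss; t holds integer labels."""
        idx = self.rng.permutation(len(X))
        for start in range(0, len(X), batch_size):
            b = idx[start:start + batch_size]
            grad_z = softmax(self.scores(X[b]))
            grad_z[np.arange(len(b)), t[b]] -= 1.0  # dL/dz = softmax(z) - one_hot(t)
            grad_z /= len(b)
            self.W -= learning_rate * X[b].T @ grad_z
            self.b -= learning_rate * grad_z.sum(axis=0)

    def accuracy(self, X, t):
        # fraction of samples whose highest-scoring class equals the label
        return float(np.mean(np.argmax(self.scores(X), axis=1) == t))

# toy data: three well-separated 2-D clusters instead of MNIST
rng = np.random.default_rng(1)
centers = np.array([[0.0, 0.0], [4.0, 4.0], [0.0, 4.0]])
t = rng.integers(0, 3, size=300)
X = centers[t] + 0.5 * rng.standard_normal((300, 2))

net = OneLayerSketch(i_size=2, o_size=3)
for epoch in range(1, 51):
    net.train(X, t, learning_rate=0.1, batch_size=10)
print(net.accuracy(X, t))
```

In the script, OneLayerSketch(i_size=img_size, o_size=N_clss) would fit the one-layer branch, though the learning rate that suits this sketch need not match the .02 chosen for Inet1.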
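Inet2 is likewise defined in ShinNet_relu.py, which this post does not include. A hypothetical two-layer network matching the Inet2(i_size, h_size, o_size) constructor can be sketched as Affine, ReLU, Affine, softmax, with gradients obtained by backpropagation. This is not the author's ShinNet code: the name TwoLayerSketch, the He initialisation, and the exact layer layout are assumptions. XOR-style toy data is used because it is not linearly separable, which is the case where a hidden layer matters.

```python
import numpy as np

def softmax(z):
    # subtract the row-wise max before exponentiating for numerical stability
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

class TwoLayerSketch:
    """Hypothetical stand-in for Inet2: Affine -> ReLU -> Affine -> softmax."""

    def __init__(self, i_size, h_size, o_size, seed=0):
        self.rng = np.random.default_rng(seed)
        # He initialisation (std = sqrt(2 / fan_in)), a common choice for ReLU layers
        self.W1 = self.rng.standard_normal((i_size, h_size)) * np.sqrt(2.0 / i_size)
        self.b1 = np.zeros(h_size)
        self.W2 = self.rng.standard_normal((h_size, o_size)) * np.sqrt(2.0 / h_size)
        self.b2 = np.zeros(o_size)

    def forward(self, X):
        a1 = X @ self.W1 + self.b1
        h = np.maximum(a1, 0.0)  # ReLU
        return a1, h, h @ self.W2 + self.b2

    def train(self, X, t, learning_rate, batch_size):
        """One epoch of minibatch SGD on the mean cross-entropy loss (integer labels t)."""
        idx = self.rng.permutation(len(X))
        for start in range(0, len(X), batch_size):
            b = idx[start:start + batch_size]
            xb, tb = X[b], t[b]
            a1, h, z = self.forward(xb)
            dz = softmax(z)
            dz[np.arange(len(b)), tb] -= 1.0  # dL/dz = softmax(z) - one_hot(t)
            dz /= len(b)
            dh = dz @ self.W2.T               # computed before W2 is updated
            da1 = dh * (a1 > 0)               # ReLU gradient is 1 where a1 > 0, else 0
            self.W2 -= learning_rate * h.T @ dz
            self.b2 -= learning_rate * dz.sum(axis=0)
            self.W1 -= learning_rate * xb.T @ da1
            self.b1 -= learning_rate * da1.sum(axis=0)

    def accuracy(self, X, t):
        _, _, z = self.forward(X)
        return float(np.mean(np.argmax(z, axis=1) == t))

# XOR-style toy data: the label is the sign of x*y, which no linear classifier separates
rng = np.random.default_rng(2)
X = rng.choice([-2.0, 2.0], size=(400, 2)) + 0.3 * rng.standard_normal((400, 2))
t = (X[:, 0] * X[:, 1] > 0).astype(int)

net = TwoLayerSketch(i_size=2, h_size=32, o_size=2)
for epoch in range(1, 101):
    net.train(X, t, learning_rate=0.1, batch_size=20)
print(net.accuracy(X, t))
```

With the real data, TwoLayerSketch(i_size=img_size, h_size=100, o_size=N_clss) mirrors the script's Inet2 call; no MNIST accuracy is claimed for this sketch.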
DATE: 2019.07.09 - 15:46
LAST UPDATE: 2019.07.10 - 15:37

