TIM Labs

28x28の手書き数字の場合のAutoEncoder

| コメント(0) | トラックバック(0)
MNISTの28x28の手書き数字は、現実の処理でも使うくらいのドット数がある。
これでAutoEndoderを作って、どのくらい再現性があるか調べてみよう。

まず、クラスMyAeのノード数の部分を書き換えよう。
中間層は、40ノードとする。つまり、748 --> 40 --> 748 と変換していこう。

class MyAE(Chain):
    def __init__(self):
        super(MyAE, self).__init__(
            l1=L.Linear(784,40),
            l2=L.Linear(40,784),
        )
変更はたったこれだけで、28x28の画像に対応できるようになる。
後は、書くだけだ。

最初に、データを読み込む。

# http://yann.lecun.com/exdb/mnist/
train, test = chainer.datasets.get_mnist()
xtrain = train._datasets[0]
ytrain = train._datasets[1]
xtest = test._datasets[0]
ytest = test._datasets[1]
今回は、学習の途中で、一定エポック毎にAutoEncoderの結果の最初の48枚の画像を1つの画像ファイルまとめて出力した。

# Learn
losslist = []
for j in range(1000000):
    x = Variable(xtrain[:10000])
    model.cleargrads()             # model.zerograds() 非推奨
    loss = model(x)
    if j%10000 == 9999:
        print( "%6d   %10.6f" % (j+1, loss.data) )
        xx = Variable(xtrain[:48], volatile='on')
        yy = model.fwd(xx)
        plotresults( yy, "mnistaeout/mnistae%d.png" % (j+1) )

    losslist.append(loss.data)     # 誤差をリストに追加
    loss.backward()
    optimizer.update()
10000エポック毎にスナップショット画像を吐き出して、全体で100万エポックまでやったのだが、このくらいやるとしっかり時間がかかり、走らせて結果は翌日確認することになった。(まだGPUは使っていない)

といことで、今回はプログラムの紹介だけで、結果は次回に示す。
今回のプログラムは:"minstae.py"

#!/usr/bin/env python
# from http://nlp.dse.ibaraki.ac.jp/~shinnou/book/chainer.tgz

import numpy as np
import chainer
from chainer import cuda, Function, gradient_check, Variable 
from chainer import optimizers, serializers, utils
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L
import matplotlib.pyplot as plt

# http://yann.lecun.com/exdb/mnist/
train, test = chainer.datasets.get_mnist()
xtrain = train._datasets[0]
ytrain = train._datasets[1]
xtest = test._datasets[0]
ytest = test._datasets[1]

class MyAE(Chain):
    def __init__(self):
        super(MyAE, self).__init__(
            l1=L.Linear(784,40),
            l2=L.Linear(40,784),
        )
        
    def __call__(self,x):
        bv = self.fwd(x)
        return F.mean_squared_error(bv, x)
        
    def fwd(self,x):
        fv = F.sigmoid(self.l1(x))
        bv = self.l2(fv)
        return bv

def plotresults(yy,filename):
    fig,ax = plt.subplots(nrows=6,ncols=8,sharex=True,sharey=True)
    ax = ax.flatten()
    for i in range(48):
        img = yy[i].data.reshape(28,28)
        ax[i].imshow(img,cmap='Greys',interpolation='none')
    ax[0].set_xticks([])
    ax[0].set_yticks([])
    plt.tight_layout()
    plt.savefig(filename)
    plt.close()

# Initialize model        
model = MyAE()
optimizer = optimizers.SGD()
optimizer.setup(model)

# Learn
losslist = []
for j in range(1000000):
    x = Variable(xtrain[:10000])
    model.cleargrads()             # model.zerograds() 非推奨
    loss = model(x)
    if j%10000 == 9999:
        print( "%6d   %10.6f" % (j+1, loss.data) )
        xx = Variable(xtrain[:48], volatile='on')
        yy = model.fwd(xx)
        plotresults( yy, "mnistaeout/mnistae%d.png" % (j+1) )

    losslist.append(loss.data)     # 誤差をリストに追加
    loss.backward()
    optimizer.update()

トラックバック(0)

トラックバックURL: http://labs.timedia.co.jp/mt/mt-tb.cgi/583

コメントする

このブログ記事について

このページは、fujiが2017年3月25日 00:00に書いたブログ記事です。

ひとつ前のブログ記事は「データ解析のための統計モデリング入門 GLMの尤度比検定と検定の非対称性 読書メモ3」です。

次のブログ記事は「28x28の手書き数字の場合のAutoEncoderの出力画像」です。

最近のコンテンツはインデックスページで見られます。過去に書かれたものはアーカイブのページで見られます。