TIM Labs

Chainer:学習結果(脳)をsaveする

| コメント(0) | トラックバック(0)
今回は、学習して、その結果をファイルにセーブする。

さて、学習結果なのだが、実際には model というDigitsChainクラスのオブジェクトの中に存在するので、このmodelをファイルにセーブすれば良いはずだ。
オブジェクトのセーブ、それもあれこれいっぱいデータが入っているはずのオブジェクトはどうやってセーブすれば良いだろうか?
Pythonには、オブジェクトをsave/loadするときに、オブジェクトを直列化してsaveするのが流儀であり、そのための仕組みが用意されている。

今回は、pickleを利用する。要するに、漬物(=pickle)にしてしまおうという訳だ。
詳しい説明はここではしないので、リンクを先を見るなどして自分で学習しよう。

import pickle
of = open('digitslearnt.pkl','wb')
pickle.dump(model,of)
of.close()
セーブは、たったこれだけで可能になる。 直列化してセーブしたデータの拡張子は".pkl"とする。
ファイル名とモード(バイナリの書き込みモード)でオープンし、そのファイルオブジェクトをofとしている。
次に、pickle.dumpで、セーブしたいオブジェクトと、ファイルオブジェクトを指定することで、ファイルに書き出してくれる。
最後にcloseする。

これだけだ。簡単過ぎる。

次回、別のプログラムで、この'digitslearnt.pkl'をロードしてテストしてみよう。
なので、テストデータもセーブしておこう。

of = open('digitstestdata.pkl','wb')
pickle.dump([xtest,yans],of)
of.close()
こんどは、2つのオブジェクトをセーブしないといけないので、リストにすることで、ひとつのオブジェクトに見えるようにした。 これで ".pkl" ファイルが2つできた。
Chainer$ ls -l *.pkl
-rw-rw-r-- 1 fuji fuji  20440  2月 15 15:56 digitslearnt.pkl
-rw-rw-r-- 1 fuji fuji 158385  2月 15 15:56 digitstestdata.pkl
今回は、ここまで。 次回、これらを読み込んで、テスト可能かどうか確かめよう。
"digits0s.py" : 学習し、セーブするところまで
#!/usr/bin/env python
# from http://nlp.dse.ibaraki.ac.jp/~shinnou/book/chainer.tgz

import numpy as np
import chainer
from chainer import cuda, Function, gradient_check, Variable
from chainer import optimizers, serializers, utils
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L

# digitsデータの読み込み
from sklearn import datasets
digits = datasets.load_digits()
X = digits.data.astype(np.float32)
Y = digits.target.astype(np.int)
N = Y.size
Y2 = np.zeros(10 * N).reshape(N,10).astype(np.float32)
for i in range(N):
    Y2[i,Y[i]] = 1.0

# 学習データ(xtrain,ytrain)とテストデータ(xtest,yans)に分ける
index  = np.arange(N)
xtrain = X[index[index % 3 != 0],:]
ytrain = Y2[index[index % 3 != 0],:]
xtest  = X[index[index % 3 == 0],:]
yans   = Y[index[index % 3 == 0]]

# 学習モデルの初期化
from digitschain import DigitsChain
model = DigitsChain()
optimizer = optimizers.SGD()
optimizer.setup(model)

# 学習ループ

for i in range(10000):
    x = Variable(xtrain)
    y = Variable(ytrain)
    model.zerograds()
    loss = model(x,y)        # lossを求める (forward)
    loss.backward()        # 微分(backward)
    optimizer.update()        # 調整

# 学習内容(脳)をセーブする

import pickle
of = open('digitslearnt.pkl','wb')
pickle.dump(model,of)
of.close()

of = open('digitstestdata.pkl','wb')
pickle.dump([xtest,yans],of)
of.close()

トラックバック(0)

トラックバックURL: http://labs.timedia.co.jp/mt/mt-tb.cgi/567

コメントする

このブログ記事について

このページは、fujiが2017年2月24日 00:00に書いたブログ記事です。

ひとつ前のブログ記事は「データ解析のための統計モデリング入門 一般化線形モデル(GLM) 読書メモ3」です。

次のブログ記事は「データ解析のための統計モデリング入門 GLMのモデル選択 読書メモ」です。

最近のコンテンツはインデックスページで見られます。過去に書かれたものはアーカイブのページで見られます。