Chainerで数字を分類してみた
手入力した数字をChainerで分類してみました。
目次
Chainerとは
言わずと知れたPythonで動くディープラーニングフレームワークです。
今回は、tkinterでGUIを作って、マウスでキャンバスに描いた数字をChainerを使って分類してみました。
推論に使用するモデルは、 以前の投稿 で作ったものを流用しました。
ちょっと凝ったところとしては、各分類クラスのソフトマックスの出力をバーで表示するようにしてみました。
コード
import tkinter as tk
import tkinter.ttk as ttk
from PIL import Image, ImageDraw
import numpy as np
import chainer
from chainer.datasets import LabeledImageDataset
from chainer.serializers import load_npz
from chainer.links import Classifier
from chainercv.links import ResNet50
from chainercv.utils import write_image
class MyProgressBars(tk.Frame):
def __init__(self, parent):
tk.Frame.__init__(self, parent)
self.results = []
labels = []
progressbars = []
for i in range(10):
self.result = tk.DoubleVar()
self.results.append(self.result)
lb = ttk.Label(self, text=str(i))
lb.grid(row=i, column=0)
bar = ttk.Progressbar(self, variable=self.results[i], orient=tk.HORIZONTAL, mode='determinate', maximum=1.0)
bar.grid(row=i, column=1)
labels.append(lb)
progressbars.append(bar)
class Application(tk.Frame):
def __init__(self, master=None):
super().__init__(master)
self.master = master
self.master.title('chainer with tkinter trial')
self.pack()
self.create_widgets()
self.setup()
def create_widgets(self):
self.label_result = ttk.Label(self, text='Classified')
self.label_result.grid(row=0, column=1)
self.result_var = tk.StringVar()
self.label_result_var = ttk.Label(self, textvariable=self.result_var)
self.label_result_var.grid(row=1, column=1)
self.result_panel = MyProgressBars(self)
self.result_panel.grid(row=2, column=1)
self.clear_button = tk.ttk.Button(self, text='clear canvas', command=self.clear_canvas)
self.clear_button.grid(row=3, column=1)
self.input_canvas = tk.Canvas(self, bg='white', width=300, height=300)
self.input_canvas.grid(row=0, column=0, rowspan=4)
self.input_canvas.bind('<B1-Motion>', self.paint)
self.input_canvas.bind('<ButtonRelease-1>', self.reset)
def setup(self):
self.old_x = None
self.old_y = None
self.im = Image.new('RGB', (300, 300), 'black')
self.draw = ImageDraw.Draw(self.im)
self.result_var.set('')
def clear_canvas(self):
self.input_canvas.delete(tk.ALL)
self.im = Image.new('RGB', (300, 300), 'black')
self.draw = ImageDraw.Draw(self.im)
self.result_var.set('')
def paint(self, event):
if self.old_x and self.old_y:
self.input_canvas.create_line(self.old_x, self.old_y, event.x, event.y, width=30.0, fill='black', capstyle=tk.ROUND, smooth=tk.TRUE, splinesteps=36)
self.draw.line((self.old_x, self.old_y, event.x, event.y), fill='white', width=30)
self.old_x = event.x
self.old_y = event.y
def reset(self, event):
self.old_x, self.old_y = None, None
result = classification(self.im)
self.result_var.set(str(np.argmax(result)))
for i in range(10):
self.result_panel.results[i].set(result[i])
def classification(img):
# 画像をChainer用に変換
img = img.convert('L')
img = img.resize((28,28))
array_img = np.asarray(img).astype(np.float32)
# テストデータに対して推論を実行
x = array_img[np.newaxis, :, :]
with chainer.using_config('train', False), chainer.using_config('enable_backprop', False):
#y = model.predictor(x[None, ...]).data.argmax(axis=1)[0]
y = model.predictor(x[None, ...]).data
print(y[0])
return y[0]
# ネットワークのインスタンスを作る
extractor = ResNet50(n_class=10)
extractor.pick = 'prob'
model = Classifier(extractor)
load_npz('./my_resnet.model', model)
root = tk.Tk()
app = Application(master=root)
app.mainloop()
ChainerCVのResNet50を使う場合、extractor.pickをprobにすると、ソフトマックス関数を通した出力が得られます。
実際に動かしてみると、こうなります。
この映像は、CPU(Ryzen 2700X)で動作させたときのものです。ResNet50はそれなりに大きなモデルですが、まずますの応答をしてますね。
7とか8とかが間違ってますが、おそらく学習の不足です。10エポックしか回してませんからね。
公開日
広告
Chainerカテゴリの投稿
- ChainerCVで使える画像のデータ拡張
- ChainerCVで画像を出力する方法
- ChainerCVのResNetを使う
- ChainerCVのSSDに学習させてみた
- ChainerCVのデモンストレーションプログラムを読んでみた
- ChainerCVのデモンストレーションプログラムを読んでみた(推論編)
- Chainerが出力するネットワーク構造図をGraphvizで見る
- Chainerで数字を分類してみた
- ChainerのSSDのデモで物体検出をしてみる
- Chainerのチュートリアルを試してみた
- Chainerのチュートリアルを試してみた(ChainerCVでデータ拡張編)
- Chainerのチュートリアルを試してみた(データ拡張編)
- Chainerのチュートリアルを試してみた(トレーナー編)
- Chainerのチュートリアルを試してみた(畳み込みを深くする編)
- Chainerのチュートリアルを試してみた(畳み込み編)
- Chainerのデータセットの作り方(ラベル付き画像編)
- VoTTのPascal VOC出力をChainerCVのデータセットとして読み込んでみた