Chainerで数字を分類してみた

手入力した数字をChainerで分類してみました。

目次

  1. Chainerとは
  2. コード

Chainerとは

言わずと知れたPythonで動くディープラーニングフレームワークです。

今回は、tkinterでGUIを作って、マウスでキャンバスに描いた数字をChainerを使って分類してみました。

推論に使用するモデルは、 以前の投稿 で作ったものを流用しました。

ちょっと凝ったところとしては、各分類クラスのソフトマックスの出力をバーで表示するようにしてみました。

tkinterの使い方については このあたりの投稿 を、Chainerでの推論の仕方は こちら を参考にしてください。

コード

import tkinter as tk
import tkinter.ttk as ttk
from PIL import Image, ImageDraw
import numpy as np
import chainer
from chainer.datasets import LabeledImageDataset
from chainer.serializers import load_npz
from chainer.links import Classifier
from chainercv.links import ResNet50
from chainercv.utils import write_image

class MyProgressBars(tk.Frame):
    def __init__(self, parent):
        tk.Frame.__init__(self, parent)

        self.results = []
        labels = []
        progressbars = []
        for i in range(10):
            self.result = tk.DoubleVar()
            self.results.append(self.result)
            lb = ttk.Label(self, text=str(i))
            lb.grid(row=i, column=0)
            bar = ttk.Progressbar(self, variable=self.results[i], orient=tk.HORIZONTAL, mode='determinate', maximum=1.0)
            bar.grid(row=i, column=1)
            labels.append(lb)
            progressbars.append(bar)

class Application(tk.Frame):
    def __init__(self, master=None):
        super().__init__(master)
        self.master = master
        self.master.title('chainer with tkinter trial')
        self.pack()
        self.create_widgets()
        self.setup()

    def create_widgets(self):
        self.label_result = ttk.Label(self, text='Classified')
        self.label_result.grid(row=0, column=1)
        self.result_var = tk.StringVar()
        self.label_result_var = ttk.Label(self, textvariable=self.result_var)
        self.label_result_var.grid(row=1, column=1)

        self.result_panel = MyProgressBars(self)
        self.result_panel.grid(row=2, column=1)

        self.clear_button = tk.ttk.Button(self, text='clear canvas', command=self.clear_canvas)
        self.clear_button.grid(row=3, column=1)

        self.input_canvas = tk.Canvas(self, bg='white', width=300, height=300)
        self.input_canvas.grid(row=0, column=0, rowspan=4)
        self.input_canvas.bind('<B1-Motion>', self.paint)
        self.input_canvas.bind('<ButtonRelease-1>', self.reset)

    def setup(self):
        self.old_x = None
        self.old_y = None
        self.im = Image.new('RGB', (300, 300), 'black')
        self.draw = ImageDraw.Draw(self.im)
        self.result_var.set('')

    def clear_canvas(self):
        self.input_canvas.delete(tk.ALL)
        self.im = Image.new('RGB', (300, 300), 'black')
        self.draw = ImageDraw.Draw(self.im)
        self.result_var.set('')

    def paint(self, event):
        if self.old_x and self.old_y:
            self.input_canvas.create_line(self.old_x, self.old_y, event.x, event.y, width=30.0, fill='black', capstyle=tk.ROUND, smooth=tk.TRUE, splinesteps=36)
            self.draw.line((self.old_x, self.old_y, event.x, event.y), fill='white', width=30)
        self.old_x = event.x
        self.old_y = event.y

    def reset(self, event):
        self.old_x, self.old_y = None, None
        result = classification(self.im)
        self.result_var.set(str(np.argmax(result)))
        for i in range(10):
            self.result_panel.results[i].set(result[i])


def classification(img):

    # 画像をChainer用に変換
    img = img.convert('L')
    img = img.resize((28,28))
    array_img = np.asarray(img).astype(np.float32)

    # テストデータに対して推論を実行
    x = array_img[np.newaxis, :, :]
    with chainer.using_config('train', False), chainer.using_config('enable_backprop', False):
        #y = model.predictor(x[None, ...]).data.argmax(axis=1)[0]
        y = model.predictor(x[None, ...]).data
        print(y[0])
    return y[0]


# ネットワークのインスタンスを作る
extractor = ResNet50(n_class=10)
extractor.pick = 'prob'
model = Classifier(extractor)
load_npz('./my_resnet.model', model)

root = tk.Tk()
app = Application(master=root)
app.mainloop()

ChainerCVのResNet50を使う場合、extractor.pickをprobにすると、ソフトマックス関数を通した出力が得られます。

実際に動かしてみると、こうなります。

この映像は、CPU(Ryzen 2700X)で動作させたときのものです。ResNet50はそれなりに大きなモデルですが、まずますの応答をしてますね。

7とか8とかが間違ってますが、おそらく学習の不足です。10エポックしか回してませんからね。

広告

Chainerカテゴリの投稿