PyTorch チュートリアルにトライ 3 (画像分類)
PyTorchへの入門として、公式のチュートリアルをなぞってみました。 PyTorch本家のチュートリアル に従って、CIFAR-10の画像分類をしてみたいと思います。
目次
CIFAR-10画像データのダウンロード
CIFAR-10というのはディープラーニングのベンチマークや入門書でよく扱われる画像データセットです。60,000枚の32x32のカラー画像が10クラスに分類されています。
PyTorchにはこのデータセットをダウンロードする機能があります。
import torch
import torchvision
import torchvision.transforms as transforms
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
このプログラムを実行すると、インターネット経由でデータをダウンロードします。
ダウンロード先はdatasetメソッドのroot属性に指定したディレクトリです。そのディレクトリにダウンロード済みであれば、再ダウンロードは行わず、ディレクトリ内のデータを読み込みます。
実行すると、1回目はこんな感じでデータをダウンロードします。
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz
100%|██████████████████████████████████████▉| 170483712/170498071 [36:39<00:00, 62355.34it/s]Extracting ./data\cifar-10-python.tar.gz to ./data
Files already downloaded and verified
170500096it [36:44, 77340.39it/s]
データを表示してみる
ダウンロードしたデータは、画像データとして保存されてるのではなくPyTorchに便利な形のデータとして保存されます。画像単体ではなく、画像とラベルのセットですからね。
というわけで、ダウンロードしたデータの中からいくつか表示してみましょう。
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
def imshow(img):
img = img / 2 + 0.5 # 非正規化
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=0)
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# 画像のピックアップ
dataiter = iter(trainloader)
images, labels = dataiter.next()
# 表示
imshow(torchvision.utils.make_grid(images))
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
Windows10で実行するとBrokenPipeErrorというエラーが出る場合があります。(私は出ました。) この場合は、DataLoaderメソッドのnum_workers属性の値を0に変更します。
実行すると画像が表示され、それを閉じると標準出力に下記のように各画像のラベルが表示されます。
deer bird plane frog
学習を実行する
本家のチュートリアルでは部分毎に分けてコードが記載さていますが、本投稿ではローカルで実行しますので、まとめたコードにします。
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# ニューラルネットワークの定義
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# データセット読み込み
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=0)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=0)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# ネットワークのインスタンス作成
net = Net()
# 損失関数とオプティマイザーの定義
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# 学習ループ
for epoch in range(2): # 総データに対する学習回数
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
# データをリストに格納
inputs, labels = data
# パラメータを0にリセット
optimizer.zero_grad()
# 順方向の計算、損失計算、バックプロパゲーション、パラメータ更新
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# 計算状態の出力
running_loss += loss.item()
if i % 2000 == 1999: # print every 2000 mini-batches
print('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 2000))
running_loss = 0.0
print('Finished Training')
# 計算結果のモデルを保存
torch.save(net.state_dict(), './cifar_net.pth')
最初のNetクラスで、ニューラルネットワークの定義をします。クラスのコンストラクターで層の定義をして、forwardメソッドに各層のつながりを定義します。
ニューラルネットワークの定義の後は、概ね下記のようにコードを書いています。
データセットの読み込み
ニューラルネットワークのインスタンス生成
損失関数の設定
オプティマイザーの設定
学習ループ
データセットから1バッチ分のデータを取り出し
パラメータの初期設定
順方向の計算
損失計算
バックプロパゲーション
パラメータ更新
モデルの保存
実行すると、出力はこうなります。
Files already downloaded and verified
Files already downloaded and verified
[1, 2000] loss: 2.244
[1, 4000] loss: 1.852
[1, 6000] loss: 1.677
[1, 8000] loss: 1.577
[1, 10000] loss: 1.516
[1, 12000] loss: 1.480
[2, 2000] loss: 1.397
[2, 4000] loss: 1.391
[2, 6000] loss: 1.364
[2, 8000] loss: 1.331
[2, 10000] loss: 1.334
[2, 12000] loss: 1.291
Finished Training
Surface Pro 6(Core i5-8250U)で、5分くらいで計算が終わります。
推論する
テストデータを使って推論してみます。
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
# ニューラルネットワークの定義
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
def imshow(img):
img = img / 2 + 0.5
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
# データセット読み込み
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=0)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# データ取り出し
dataiter = iter(testloader)
images, labels = dataiter.next()
# 正解の表示
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))
# ネットワークモデル読み込み
net = Net()
net.load_state_dict(torch.load('./cifar_net.pth'))
# テストデータについて推論する
outputs = net(images)
_, predicted = torch.max(outputs, 1)
print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))
概ね下記のような手順で推論します。
データセットの読み込み
ネットワークのインスタンス生成
モデルデータの読み込み
順方向計算
最も適合するものを選択
これを実行すると、推論の対象になる画像が表示され、正解と推論の結果が表示されます。
Files already downloaded and verified
GroundTruth: cat ship ship plane
Predicted: frog ship car plane
答えが合わないですね。
精度を計算する
全てのテストデータに対して推論を行い、モデルの精度を計算してみます。
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
# ニューラルネットワークの定義
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# データセット読み込み
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=0)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# ネットワークモデル読み込み
net = Net()
net.load_state_dict(torch.load('./cifar_net.pth'))
# 精度の計算
correct = 0
total = 0
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))
これを実行するとこうなります。
Files already downloaded and verified
Accuracy of the network on the 10000 test images: 55 %
55%では、当てずっぽうとあまり変わりませんね。
分類ごとの精度を計算する
分類毎に得手不得手があるかもしれませんので、分類毎の精度を計算してみます。
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
# ニューラルネットワークの定義
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# データセット読み込み
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=0)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# ネットワークモデル読み込み
net = Net()
net.load_state_dict(torch.load('./cifar_net.pth'))
# 精度の計算
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs, 1)
c = (predicted == labels).squeeze()
for i in range(4):
label = labels[i]
class_correct[label] += c[i].item()
class_total[label] += 1
for i in range(10):
print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))
これを実行すると、出力はこうなります。
Files already downloaded and verified
Accuracy of plane : 60 %
Accuracy of car : 70 %
Accuracy of bird : 37 %
Accuracy of cat : 25 %
Accuracy of deer : 58 %
Accuracy of dog : 65 %
Accuracy of frog : 57 %
Accuracy of horse : 61 %
Accuracy of ship : 62 %
Accuracy of truck : 58 %
catが圧倒的に苦手ですね。
公開日
広告