ニューラルネットワークの損失関数

ニューラルネットワークの損失関数についてのメモです。

目次

  1. 損失関数
  2. 2乗和誤差
  3. 交差エントロピー誤差

損失関数

ニューラルネットワークの性能がどれだけ悪いのかを表す指標です。これを最小化することを狙って学習していきます。

2乗和誤差

2乗和誤差の式

yはニューラルネットワークの出力層の各ニューロンの出力です。活性化関数としてソフトマックス関数で処理されているものとします。

kは出力層のニューロンの数です。

tは教師データです。正解の要素だけに1を、それ以外の要素は0を設定します。

数式の形をみると、各要素の差を計算して2乗したものを合計しています。つまりこれは、各要素ごとに数値の距離を測って合計しているわけです。距離を計算していますので、正解から遠ければ遠いほど値が大きくなります。

例として、10要素の出力の2乗和誤差を計算してみます。

まず、教師データではインデックス番号2の要素が正解の場合に、ニューラルネットワークがインデックス2の確率が高いと判断した例です。

import matplotlib.pyplot as plt
import numpy as np

y = np.array([0.01, 0.1, 0.8, 0.03, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01])
t = np.array([0, 0, 1, 0, 0, 0, 0, 0, 0, 0])
e = 0.5 * np.sum((y - t)**2)

plt.plot(range(10), y)
plt.plot(range(10), t)
plt.ylim(-0.1,1.1)
plt.show()
print(e)
2乗和誤差のグラフ(正解)

赤が教師データ、青がニューラルネットワークの出力です。要素ごと(インデックス番号ごと)の数値の距離(違い)が少なくて、性能が良さそうです。2乗和誤差の値も0.0258と低い値になっています。

次に、ニューラルネットワークがインデックス番号6の確率が高いと判定した場合です。

import matplotlib.pyplot as plt
import numpy as np
y = np.array([0.02, 0.1, 0.2, 0.02, 0.01, 0.1, 0.5, 0.03, 0.01, 0.01])
t = np.array([0, 0, 1, 0, 0, 0, 0, 0, 0, 0])
e = 0.5 * np.sum((y - t)**2)
plt.plot(range(10), y)
plt.plot(range(10), t)
plt.ylim(-0.1,1.1)
plt.show()
print(e)
2乗和誤差のグラフ(不正解)

赤が教師データ、青がニューラルネットワークの出力です。要素ごと(インデックス番号ごと)の数値の差が先のものよりかなり大きいですね。このニューラルネットワークは明らかに先のものより性能が悪いということがわかります。2乗和誤差の値も0.456と先の物より大きくなりました。

このように、2乗和誤差は出力と正解の距離を計算しますので、数値が小さい方が性能が良いという指標になります。

交差エントロピー誤差

交差エントロピー誤差の式

yはニューラルネットワークの出力層の各ニューロンの出力です。活性化関数としてソフトマックス関数で処理されているものとします。

kは出力層のニューロンの数です。

tは教師データです。正解の要素だけに1を、それ以外の要素は0を設定します。

yはソフトマックス関数の出力ですから、0から1の間の値です。その場合にeを底とする対数は次のグラフになります。

対数のグラフ

また、tとyを掛けてますので、教師データが正解の要素以外は計算結果が0になります。つまり、教師データが正解のインデックス番号の要素だけを取り出して対数の計算をしているわけです。

2乗和誤差のときと同じデータで計算してみます。

まずニューラルネットワークの出力が正解に近い場合です。

import matplotlib.pyplot as plt
import numpy as np

y = np.array([0.01, 0.1, 0.8, 0.03, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01])
t = np.array([0, 0, 1, 0, 0, 0, 0, 0, 0, 0])
e = -np.sum(t * np.log(y + 1e-7))

plt.plot(range(10), y)
plt.plot(range(10), t)
plt.ylim(-0.1,1.1)
plt.show()
print(e)
交差エントロピー誤差のグラフ(正解)

次に、ニューラルネットワークの出力が正解から遠い場合です。

import matplotlib.pyplot as plt
import numpy as np

y = np.array([0.02, 0.1, 0.2, 0.02, 0.01, 0.1, 0.5, 0.03, 0.01, 0.01])
t = np.array([0, 0, 1, 0, 0, 0, 0, 0, 0, 0])
e = -np.sum(t * np.log(y + 1e-7))

plt.plot(range(10), y)
plt.plot(range(10), t)
plt.ylim(-0.1,1.1)
plt.show()
print(e)
交差エントロピー誤差のグラフ(不正解)

インデックス番号2の要素だけに着目すると、ニューロンの出力の値が小さい(性能が悪い)方が交差エントロピー誤差が大きくなりました。

交差エントロピー誤差は、正解のインデックスの要素の対数をとりますので、絶対値が小さいほど性能が良いという指標になります。

公開日

広告