アイキャッチ画像

PyTorch チュートリアルにトライ 1 (テンソル)

PyTorchへの入門として、公式のチュートリアルをなぞってみました。

目次

  1. PyTorch公式チュートリアル
  2. autogradパッケージ
  3. TensorとFunction
  4. 試してみた

PyTorch公式チュートリアル

PyTorchの公式サイトにチュートリアルがあります。今回は こちら を参考にしました。

本投稿では、ローカルにインストールしたPyTorchに対して、コマンドラインで直接実行しています。

autogradパッケージ

autogradというパッケージでテンソルの自動微分(Automatic differentiation)を行います。これがPyTorchのキモです。

TensorとFunction

touch.Tensorというクラスがこのパッケージの重要な部分です。requires_gradという属性をTureにすると、このテンソルの操作が記録されます。計算終了後にbackward()を実行すると、自動的に勾配が計算されます。勾配はgrad属性に蓄積されます。

テンソル計算の追跡を止めるには、detach()を実行します。

with touch.no_grad():というコードブロック内では、計算が追跡されなくなります。

もう一つ重要なクラスとしてFunctionがあります。

TensorとFunctionは内部接続により非循環グラフを構成します。各テンソルはgrad_fnという、Tensorを作成した際に使用されたFunctionを参照する属性を持っています。(ユーザーが直接作成したTensorのgrad_fn属性は、Noneになります。)

試してみた

まず、2x2で各要素が1のテンソルを作成してみます。

>>> import torch
>>> x = torch.ones(2, 2, requires_grad=True)
>>> print(x)
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)

このテンソル x に2を足してみます。そして、計算結果のテンソル y のgrad_fn属性を表示してみます。

>>> import torch
>>> x = torch.ones(2, 2, requires_grad=True)
>>> y = x + 2
>>> print(y)
tensor([[3., 3.],
        [3., 3.]], grad_fn=<AddBackward0>)
>>> print(y.grad_fn)
<AddBackward0 object at 0x00000212A6886780>

grad_fn属性にAddがありますね。

テンソル y を使ってさらにテンソル z とその平均 out を計算してみます。

>>> import torch
>>> x = torch.ones(2, 2, requires_grad=True)
>>> y = x + 2
>>> z = y * y * 3
>>> print(z)
tensor([[27., 27.],
        [27., 27.]], grad_fn=<MulBackward0>)

>>> out = z.mean()
>>> print(out)
tensor(27., grad_fn=<MeanBackward0>)

計算内容によってgrad_fn属性が異なっているのがわかります。

バックプロパゲーションしてみます。

>>> import torch
>>> x = torch.ones(2, 2, requires_grad=True)
>>> y = x + 2
>>> z = y * y * 3
>>> out = z.mean()

>>> print(x.grad)
None
>>> out.backward()
>>> print(x.grad)
tensor([[4.5000, 4.5000],
        [4.5000, 4.5000]])

元のテンソル x のgrad属性に微分(d(out)/dx)が入りました。

requires_grad属性を変更してみます。

>>> import torch
>>> a = torch.randn(2,2)
>>> print(a)
tensor([[ 0.6139,  0.4427],
        [ 0.8860, -1.0564]])

>>> a = ((a * 3)/(a - 1))
>>> print(a)
tensor([[ -4.7701,  -2.3828],
        [-23.3140,   1.5412]])
>>> print(a.requires_grad)
False

>>> a.requires_grad_(True)
tensor([[ -4.7701,  -2.3828],
        [-23.3140,   1.5412]], requires_grad=True)
>>> print(a.requires_grad)
True

>>> b = (a * a).sum()
>>> print(b)
tensor(574.3494, grad_fn=<SumBackward0>)
>>> print(b.grad_fn)
<SumBackward0 object at 0x00000212A6886780>

計算の途中からrequires_grad属性がTrueになりました。

スカラーのテンソルに対しては、backwardの計算ができません。下記の例ではランタイムエラーになりました。

>>> import torch
>>> x = torch.randn(3, requires_grad=True)
>>> x
tensor([-0.5738, -0.2923,  2.0071], requires_grad=True)

>>> y = x * 2
>>> while y.data.norm() < 1000:
...     y = y * 2
...
>>> print(y)
tensor([-293.7970, -149.6342, 1027.6301], grad_fn=<MulBackward0>)
>>> print(x.grad)
None

>>> y.backward()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\Program Files\Python37\lib\site-packages\torch\tensor.py", line 198, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "C:\Program Files\Python37\lib\site-packages\torch\autograd\__init__.py", line 94, in backward
    grad_tensors = _make_grads(tensors, grad_tensors)
  File "C:\Program Files\Python37\lib\site-packages\torch\autograd\__init__.py", line 35, in _make_grads
    raise RuntimeError("grad can be implicitly created only for scalar outputs")
RuntimeError: grad can be implicitly created only for scalar outputs

その場合には、backwardにテンソルを渡します。

>>> import torch
>>> x = torch.randn(3, requires_grad=True)
>>> y = x * 2
>>> while y.data.norm() < 1000:
...     y = y * 2
...
>>> print(x.grad)
None

>>> v = torch.tensor([0.1, 1.0, 0.0001], dtype=torch.float)
>>> v
tensor([1.0000e-01, 1.0000e+00, 1.0000e-04])

>>> y.backward(v)
>>> print(x.grad)
tensor([5.1200e+01, 5.1200e+02, 5.1200e-02])

touch.no_gradブロックを使ってみます。

>>> import torch
>>> x = torch.randn(3, requires_grad=True)
>>> x
tensor([-0.5738, -0.2923,  2.0071], requires_grad=True)
>>> print(x.requires_grad)
True
>>> print((x ** 2).requires_grad)
True

>>> with torch.no_grad():
...     print(x.requires_grad)
...     print((x ** 2).requires_grad)
...
True
False

torch.no_gradの中で計算したテンソルについては、requires_grad属性がFalseになりました。推論の際はこうすると良いのかもしれません。

公開日

広告