【速習】Pytorch入門①:torchを扱ってみる

2020年4月12日

PyTorchとは

PyTorchとは、Facebookが開発したディープラーニングフレームワークです。TensorFlowやkerasに比べると利用者人口は少ないですが、柔軟なネットワーク構築が可能なDefine by run形式を特徴とし、今急速に成長しています。もともとはChainerのforkでしたが、先行して欧米圏の研究者に広く普及したことから、現在はPFN側がChainerを中止し、PyTorchの開発に協力・合流する形となったようです。

PyTorchは以下の2つと特徴としています。

・GPUによる高速化が可能なNumPyに相当する「torch」
・柔軟性で高速なDefineByRun型の深層学習プラットフォーム

PyTorchのインストール方法はこちらの公式ページから:https://pytorch.org/
詳細などは多くの日本語記事にて取り上げられています(参考:PyTorch 入門!人気急上昇中のPyTorchで知っておくべき6つの基礎知識

Torchの使い方

Pytorchでは、numpy型のデータを入力しても計算はできず、torch.tensorというデータ型を使って演算を行います。そのためデータはtorch.tensor型で作成・変換する必要があります。これは、ほぼnumpyのようなものですが、NvidiaのGPUで高速な演算が可能です。

また、torchモジュール内には多次元テンソル(高次の行列みたいなもの)のデータ構造が含まれており、テンソルの計算や型変換などを効率的に行うことができます。

In [2]:
・ .size()でtensorサイズを確認できる
・ リストのスライスで行列の要素を取り出すことができ、numpy配列と同様に扱える。
In [3]: 加減剰余や微分など基本的な演算が実行可能
In [4]: .view()で配列の形状を変更
In [5]: numpyと相互変換できる
In [6]: GPU上で計算させるには、.to(device)で渡す