【속습】Pytorch 입문①:torch를 다루어 본다

2020/4/12

PyTorch란?

PyTorch는 Facebook이 개발한 딥 러닝 프레임워크입니다. TensorFlow나 keras에 비하면 이용자 인구는 적지만 유연한 네트워크 구축이 가능한 Define by run 형식을 특징으로 하여 지금 급속히 성장하고 있습니다.원래는 Chainer의 포크였지만, 선행하여 구미권의 연구자에게 널리 보급되었기 때문에, 현재는 PFN측이 Chainer를 중지해, PyTorch의 개발에 협력·합류하는 형태가 된 것 같습니다.

PyTorch는 다음 두 가지로 특징입니다.

・GPU에 의한 고속화가 가능한 NumPy에 상당하는 「torch」
· 유연하고 빠른 DefineByRun 타입의 심층 학습 플랫폼

PyTorch 설치 방법은 여기의 공식 페이지에서 :https://pytorch.org/
자세한 등은 많은 일본어 기사에서 다루고 있습니다 (참고 :PyTorch 입문!인기 급상승 중인 PyTorch에서 알아야 할 6가지 기초 지식)

Torch 사용법

Pytorch에서는 numpy형의 데이터를 입력해도 계산은 할 수 없고, torch.tensor라는 데이터형을 사용해 연산을 실시합니다.그 때문에 데이터는 torch.tensor형으로 작성·변환할 필요가 있습니다.이것은 거의 numpy와 같은 것이지만 Nvidia의 GPU에서 빠른 연산이 가능합니다.

또, torch 모듈내에는 다차원 텐서(고차의 행렬 같은 것)의 데이터 구조가 포함되어 있어 텐서의 계산이나 형 변환 등을 효율적으로 실시할 수 있습니다.

In [2]:
· .size()로 tensor 크기를 확인할 수 있습니다.
· 리스트의 슬라이스로 행렬의 요소를 꺼낼 수가 numpy 배열과 같이 취급할 수 있다.
In [3]: 가감잉여나 미분 등 기본적인 연산이 실행 가능
In [4]: .view()로 배열의 모양 변경
In [5]: numpy와 상호 변환 가능
In [6]: GPU에서 계산하려면 .to(device)로 전달