[Pembelajaran Cepat] Pengantar Pytorch ②: Pelajari dasar-dasar PyTorch dengan menerapkan regresi MLP

Terakhir kali, kami meninjau cara menangani obor, yang merupakan dasar dari PyTorch.

Kali ini saya ingin mengimplementasikan regresi Multilayer Perceptron (MLP) di PyTorch dan meninjau paket utama PyTorch.

aliran keseluruhan

Alur implementasi model secara keseluruhan adalah sebagai berikut. () adalah paket PyTorch yang digunakan saat itu.

  • Buat input dataset, ulangi (Dataset, Dataloader)
  • Definisi jaringan saraf (nn.Module)
  • Hitung kerugian, sebarkan gradien ke parameter jaringan (nn.Module)
  • Perbarui bobot jaringan (Pengoptimal)

Paket PyTorch utama

Terlepas dari PyTorch, banyak kerangka pembelajaran mendalam menyediakan berbagai paket untuk menyederhanakan implementasi. PyTorch terutama memiliki yang berikut:

tensor oborArray multidimensi. Struktur data yang digunakan oleh PyTorch.
obor.autogradTerapkan forward/backpropagation.Mendukung operasi diferensiasi otomatis pada Tensor seperti backpropagation (backward()).
obor.utils.dataIni termasuk utilitas seperti "Dataset" yang mengumpulkan data input dan labelnya sebagai satu set, dan "Dataloader" yang mengekstrak data dari Dataset dalam batch mini dan meneruskannya ke model.
torch.nn.ModulDigunakan untuk membangun jaringan saraf.Bertanggung jawab untuk mengenkapsulasi parameter seperti menyimpan dan memuat model dan memindahkannya ke GPU.
obor.optimMemungkinkan penggunaan algoritme pengoptimalan parameter seperti SDG dan Adam.
Paket PyTorch utama

Implementasi model

Buat kumpulan data

Kali ini kita akan menyiapkan sin(5x) plus bilangan acak dengan numpy sebagai data latihan.from_numpy()Konversikan ke torch.tensor dengan

Definisi model

Di pytorch, model didefinisikan sebagai "kelas python" yang mewarisi kelas nn.Module.

kelas MLP(nn.Modul): Kelas MLP yang ditentukan mewarisi kelas induk nn.Module
def init(): menerima argumen dan instantiate
super(MLP, mandiri).init(): Mewarisi kelas induk dengan fungsi super
def maju (diri, x): Setelah membuat instance, ini berfungsi saat fungsi dipanggil. Mendefinisikan fungsi maju secara otomatis mendefinisikan fungsi mundur (perhitungan gradien)

.parameters()Anda bisa mendapatkan struktur dan parameter jaringan dengan

Hitung kerugian, propagasi balik, perbarui bobot

Untuk memahami perilaku masing-masing, ambil sepotong data dari x, masukkan ke jaringan saraf, dan lihat bagaimana parameter berubah karena perhitungan kesalahan dan pembaruan bobot.

menjalankan lingkaran pembelajaran

Lakukan alur di atas untuk setiap batch untuk melatih jaringan saraf.
Datasetmengembalikan satu set data dan label yang sesuai, danDataLoaderadalah kelas yang mengembalikan data dalam ukuran batch.

Visualisasi grafik komputasi

Struktur MLP tiga lapis yang dibuat kali ini dapat divisualisasikan menggunakan paket python bernama torchviz.parameters()Silakan gunakan saat itu saja tidak cukup.

Kami telah melihat PyTorch dan paket PyTorch utamanya melalui implementasi regresi MLP kami.