[Pembelajaran Cepat] Pengantar Pytorch ②: Pelajari dasar-dasar PyTorch dengan menerapkan regresi MLP
Terakhir kali, kami meninjau cara menangani obor, yang merupakan dasar dari PyTorch.
Kali ini saya ingin mengimplementasikan regresi Multilayer Perceptron (MLP) di PyTorch dan meninjau paket utama PyTorch.
aliran keseluruhan
Alur implementasi model secara keseluruhan adalah sebagai berikut. () adalah paket PyTorch yang digunakan saat itu.
- Buat input dataset, ulangi (Dataset, Dataloader)
- Definisi jaringan saraf (nn.Module)
- Hitung kerugian, sebarkan gradien ke parameter jaringan (nn.Module)
- Perbarui bobot jaringan (Pengoptimal)
Paket PyTorch utama
Terlepas dari PyTorch, banyak kerangka pembelajaran mendalam menyediakan berbagai paket untuk menyederhanakan implementasi. PyTorch terutama memiliki yang berikut:
tensor obor | Array multidimensi. Struktur data yang digunakan oleh PyTorch. |
obor.autograd | Terapkan forward/backpropagation.Mendukung operasi diferensiasi otomatis pada Tensor seperti backpropagation (backward()). |
obor.utils.data | Ini termasuk utilitas seperti "Dataset" yang mengumpulkan data input dan labelnya sebagai satu set, dan "Dataloader" yang mengekstrak data dari Dataset dalam batch mini dan meneruskannya ke model. |
torch.nn.Modul | Digunakan untuk membangun jaringan saraf.Bertanggung jawab untuk mengenkapsulasi parameter seperti menyimpan dan memuat model dan memindahkannya ke GPU. |
obor.optim | Memungkinkan penggunaan algoritme pengoptimalan parameter seperti SDG dan Adam. |
Implementasi model
Buat kumpulan data
Kali ini kita akan menyiapkan sin(5x) plus bilangan acak dengan numpy sebagai data latihan.from_numpy()
Konversikan ke torch.tensor dengan
Definisi model
Di pytorch, model didefinisikan sebagai "kelas python" yang mewarisi kelas nn.Module.
kelas MLP(nn.Modul): Kelas MLP yang ditentukan mewarisi kelas induk nn.Module
def init(): menerima argumen dan instantiate
super(MLP, mandiri).init(): Mewarisi kelas induk dengan fungsi super
def maju (diri, x): Setelah membuat instance, ini berfungsi saat fungsi dipanggil. Mendefinisikan fungsi maju secara otomatis mendefinisikan fungsi mundur (perhitungan gradien)
.parameters()
Anda bisa mendapatkan struktur dan parameter jaringan dengan
Hitung kerugian, propagasi balik, perbarui bobot
Untuk memahami perilaku masing-masing, ambil sepotong data dari x, masukkan ke jaringan saraf, dan lihat bagaimana parameter berubah karena perhitungan kesalahan dan pembaruan bobot.
menjalankan lingkaran pembelajaran
Lakukan alur di atas untuk setiap batch untuk melatih jaringan saraf.Dataset
mengembalikan satu set data dan label yang sesuai, danDataLoader
adalah kelas yang mengembalikan data dalam ukuran batch.
Visualisasi grafik komputasi
Struktur MLP tiga lapis yang dibuat kali ini dapat divisualisasikan menggunakan paket python bernama torchviz.parameters()
Silakan gunakan saat itu saja tidak cukup.
Kami telah melihat PyTorch dan paket PyTorch utamanya melalui implementasi regresi MLP kami.
diskusi
Daftar komentar
Belum ada komentar