[Học nhanh] Giới thiệu về Pytorch ②: Tìm hiểu kiến ​​thức cơ bản về PyTorch bằng cách triển khai hồi quy MLP

Lần trước, chúng tôi đã xem xét cách xử lý ngọn đuốc, vốn là nền tảng của PyTorch.

Lần này, tôi muốn triển khai hồi quy Multilayer Perceptron (MLP) trong PyTorch và xem xét các gói PyTorch chính.

dòng chảy tổng thể

Luồng tổng thể của việc thực hiện mô hình như sau. () là gói PyTorch được sử dụng vào thời điểm đó.

  • Tạo tập dữ liệu đầu vào, lặp (Dataset, Dataloader)
  • Định nghĩa mạng thần kinh (nn.Module)
  • Tính toán tổn thất, truyền gradient tới các tham số mạng (nn.Module)
  • Cập nhật trọng số mạng (Trình tối ưu hóa)

Gói PyTorch chính

Bất kể PyTorch, nhiều khung học sâu cung cấp các gói khác nhau để đơn giản hóa việc triển khai. PyTorch chủ yếu có những điều sau đây:

ngọn đuốcMảng nhiều chiều. Cấu trúc dữ liệu được sử dụng bởi PyTorch.
ngọn đuốc.autogradThực hiện chuyển tiếp/backpropagation.Hỗ trợ các hoạt động phân biệt tự động trên Tensors như backpropagation (backward()).
ngọn đuốc.utils.dataNó bao gồm các tiện ích như "Bộ dữ liệu" thu thập dữ liệu đầu vào và nhãn của nó dưới dạng một bộ và "Trình tải dữ liệu" trích xuất dữ liệu từ Bộ dữ liệu theo lô nhỏ và chuyển dữ liệu đó tới mô hình.
ngọn đuốc.nn.ModuleĐược sử dụng để xây dựng mạng lưới thần kinh.Chịu trách nhiệm đóng gói các tham số như lưu và tải các mô hình và chuyển chúng sang GPU.
ngọn đuốc.optimCho phép sử dụng các thuật toán tối ưu hóa tham số như SDG và Adam.
Gói PyTorch chính

triển khai mô hình

Tạo tập dữ liệu

Lần này, chúng ta sẽ chuẩn bị sin(5x) cộng với một số ngẫu nhiên với dữ liệu thực hành là numpy.from_numpy()Chuyển đổi sang torch.tensor với

định nghĩa mô hình

Trong pytorch, một mô hình được định nghĩa là một "lớp trăn" kế thừa lớp nn.Module.

lớp MLP(nn.Module): Lớp MLP được định nghĩa kế thừa lớp cha nn.Module
def init(): nhận đối số và khởi tạo
super(MLP, self).init(): Kế thừa lớp cha với siêu hàm
def về phía trước (bản thân, x): Sau khi khởi tạo, nó hoạt động khi hàm được gọi. Xác định hàm tiến sẽ tự động xác định hàm lùi (tính toán độ dốc)

.parameters()Bạn có thể lấy cấu trúc và tham số mạng với

Tính toán tổn thất, lan truyền ngược, cập nhật trọng số

Để hiểu hành vi của từng loại, hãy lấy một phần dữ liệu từ x, đưa nó vào mạng thần kinh và xem các tham số thay đổi như thế nào do tính toán lỗi và cập nhật trọng số.

chạy vòng lặp học tập

Thực hiện quy trình trên cho mỗi đợt để huấn luyện mạng nơ-ron.
Datasettrả về một tập hợp dữ liệu và nhãn tương ứng, vàDataLoaderlà một lớp trả về dữ liệu ở kích thước hàng loạt.

Trực quan hóa đồ thị tính toán

Cấu trúc MLP ba lớp được tạo lần này có thể được hiển thị bằng cách sử dụng gói python có tên là torchviz.parameters()Hãy sử dụng nó khi điều đó một mình là không đủ.

Chúng tôi đã thấy PyTorch và các gói PyTorch chính của nó thông qua triển khai hồi quy MLP của chúng tôi.