只要你接觸過深度學習(deep learning),那一定對MNIST資料集不陌生。MNIST是一個經典的手寫數字資料集,包含了60000張訓練圖片及10000張測試圖片。其中,每張圖片都是灰階的,且維度大小為28×28。
現在,讓我們用PyTorch建構神經網路,並嘗試對MNIST數字進行分類吧!
首先,我們透過torchvision函式庫來下載訓練集和測試集,並分別存放在train_data及test_data。同時,我們可以利用len()查詢這兩者的樣本數,以確認結果符合預期。
來看看訓練集中的第一張圖片長什麼樣子:
接下來,就要來建構分類模型了。在開發神經網路時,很常會遇到內建模組無法滿足需求的情況。因此在本例中,我們將建立自己的nn.Module子類別。在自行建立的nn.Module子類別中,至少必須定義一個forward()函式來決定模型的運算過程。在使用PyTorch的情況下,autograd會自動處理反向計算的部分,因此nn.Module中無需定義任何的backward函式。
建構好模型後,就要來進行訓練了。在這裡,我們只設定進行100次的訓練。同時也將訓練過程中的各資訊(如:當前迴圈的預測準確率及損失等)記錄下來,以便於稍後對結果進行視覺化。
讓我們把損失隨著訓練迴圈的變化給畫出來:
從圖中可見,損失在前10個訓練迴圈中波動很大,然後便逐漸收斂至一個很小的數值。這代表我們的訓練是成功的,接下來,同樣把訓練過程中的預測準確率變化給畫出來吧!
看來,我們的模型在訓練集中表現還不錯。在僅僅100次訓練後,就取得了97%的預測準確率。最後,我們要來看看該模型在未見過的資料(即測試資料集)上的表現。在以下程式中,小編從測試資料集中隨機選取了10張圖片,並讓模型針對這些圖片進行預測。結果如下:
我們的模型居然全部都預測成功了!有可能是模型剛好運氣不錯,因此不妨來看看模型在整個測試資料集(共10000張圖片)上的預測準確率有多少:
即使是在整個測試集上,模型的預測準確率也高達96%。由此可見,我們的模型並沒有發生過度配適(overfitting)的問題。至此,我們已經成功訓練出能夠有效分類MNIST數字的模型了。由於篇幅的限制,此處無法詳細說明PyTorch模型的內部細節,感興趣的讀者歡迎參考《核心開發者親授!PyTorch深度學習攻略》一書。本書為PyTorch官方唯一推薦教材,並且由PyTorch核心開發者所著,能以更全面的視角來進行教學。
以上的所有程式皆已整理在Colab筆記本中,鼓勵讀者可以自行嘗試調整訓練迴圈、批次大小等參數,看看能不能進一步增進模型的表現吧!