利用PyTorch建構MNIST數字分類器

--

只要你接觸過深度學習(deep learning),那一定對MNIST資料集不陌生。MNIST是一個經典的手寫數字資料集,包含了60000張訓練圖片及10000張測試圖片。其中,每張圖片都是灰階的,且維度大小為28×28。

MNIST資料集的圖片樣本

現在,讓我們用PyTorch建構神經網路,並嘗試對MNIST數字進行分類吧!

首先,我們透過torchvision函式庫來下載訓練集和測試集,並分別存放在train_data及test_data。同時,我們可以利用len()查詢這兩者的樣本數,以確認結果符合預期。

來看看訓練集中的第一張圖片長什麼樣子:

接下來,就要來建構分類模型了。在開發神經網路時,很常會遇到內建模組無法滿足需求的情況。因此在本例中,我們將建立自己的nn.Module子類別。在自行建立的nn.Module子類別中,至少必須定義一個forward()函式來決定模型的運算過程。在使用PyTorch的情況下,autograd會自動處理反向計算的部分,因此nn.Module中無需定義任何的backward函式。

建構好模型後,就要來進行訓練了。在這裡,我們只設定進行100次的訓練。同時也將訓練過程中的各資訊(如:當前迴圈的預測準確率損失等)記錄下來,以便於稍後對結果進行視覺化。

讓我們把損失隨著訓練迴圈的變化給畫出來:

從圖中可見,損失在前10個訓練迴圈中波動很大,然後便逐漸收斂至一個很小的數值。這代表我們的訓練是成功的,接下來,同樣把訓練過程中的預測準確率變化給畫出來吧!

看來,我們的模型在訓練集中表現還不錯。在僅僅100次訓練後,就取得了97%的預測準確率。最後,我們要來看看該模型在未見過的資料(即測試資料集)上的表現。在以下程式中,小編從測試資料集中隨機選取了10張圖片,並讓模型針對這些圖片進行預測。結果如下:

我們的模型居然全部都預測成功了!有可能是模型剛好運氣不錯,因此不妨來看看模型在整個測試資料集(共10000張圖片)上的預測準確率有多少:

即使是在整個測試集上,模型的預測準確率也高達96%。由此可見,我們的模型並沒有發生過度配適(overfitting)的問題。至此,我們已經成功訓練出能夠有效分類MNIST數字的模型了。由於篇幅的限制,此處無法詳細說明PyTorch模型的內部細節,感興趣的讀者歡迎參考《核心開發者親授!PyTorch深度學習攻略》一書。本書為PyTorch官方唯一推薦教材,並且由PyTorch核心開發者所著,能以更全面的視角來進行教學。

以上的所有程式皆已整理在Colab筆記本中,鼓勵讀者可以自行嘗試調整訓練迴圈、批次大小等參數,看看能不能進一步增進模型的表現吧!

--

--

施威銘研究室

致力開發AI領域的圖書、創客、教具,希望培養更多的AI人才。整合各種人才,投入創客產品的開發,推廣「實作學習」,希望實踐學以致用的理想。