訓練資料不平衡,後果有多嚴重?

--

在深度學習領域中,訓練資料量往往是決定模型表現的關鍵。若你在處理二分類的任務,或許會遇到以下問題:某一類別的資料比另一類別容易取得,因此造成不同類別的訓練資料數量有較大落差。來看看以下的例子:

在肺癌偵測專案中,我們同時需要患癌族群健康族群的肺部CT圖來訓練模型,讓模型可以根據CT圖判斷某個人是否罹患肺癌。根據美國癌症協會的統計,男性罹患肺癌的機率約6.66%,而女性的機率則為5.88%。

換言之,患癌族群和健康族群的比例是不對等的。這也就導致了專家在收集訓練資料時,較難取得有惡性腫瘤的肺部CT圖。舉例來說,在收集的每100張肺部CT圖中,可能只有其中6張包含惡性腫瘤(假設為類別0),而剩餘的94張則都是健康組織(假設為類別1)。不同類別之訓練樣本數落差很大的問題,就稱作訓練資料不平衡

接下來,小編會以CIFAR-10資料集為例,說明訓練資料不平衡所造成的後果。若要下載CIFAR-10資料集,我們可以直接使用PyTorch的torchvision函式庫,如下所示:

該資料集中有10種類別的圖片,但我們只會抽出這兩個類別來進行分類(每個類別分別有5000張圖片)。在原始資料集中,貓屬於類別3,狗屬於類別5。為了方便處理,小編把它們的類別重新對應到0(代表貓)和1(代表狗)。

現在,訓練資料集中總共有10000張訓練圖片。為了模擬資料不平衡的狀況,必須調整一下現有的資料集。小編採取的做法為:只保留dog中(存有狗的圖片,見以下程式)的500張圖片,其餘的則全部拋棄。由此一來,資料集中只會剩下5500張圖片,其中5000張為貓,剩餘的500張為狗。換句話說,貓和狗的比例約為10:1。至此,我們已經成功模擬了資料不平衡的現象:

有了不平衡的資料集後,接著就來建立名為Net的模型類別。同時,定義訓練迴圈(training_loop)和驗證函式(validate),以方便接下來的訓練及檢視結果:

現在可以來實際運行訓練迴圈,此處設定訓練100次。在每次訓練迴圈後,程式會顯示該次迴圈的訓練損失、訓練準確率和驗證準確率等資訊,這有助於我們了解模型的表現。

在第一個訓練迴圈後,訓練準確率就高達了91%,而驗證準確率卻只有50%,這一點十分詭異。即使是經過100次訓練後,這兩項準確率也依舊沒有變動。回想一下,在現有的訓練資料集中,貓的圖片(5000張)約佔了所有圖片(5500張)的91%。有沒有注意到,這個數字剛好對上了之前的訓練準確率呢?

其實,我們的模型並沒有真正學到如何分辨貓和狗。它採取的做法,不過是把所有的圖片都預測為『貓』罷了!這也解釋了為什麼模型的驗證準確率一直維持在50%:由於驗證資料集中貓和狗的圖片數量是一樣的,而模型又會無腦地將所有圖片預測為貓,因此也就只有一半的圖片能正確預測。

這就是資料不平衡所造成的問題。在訓練初期,我們可大致將5500筆訓練圖片的預測結果分成以下幾類:

  1. 有一半貓的圖片(2500張)會預測正確
  2. 有一半貓的圖片(2500張)會預測錯誤
  3. 有一半狗的圖片(250張)會預測正確
  4. 有一半狗的圖片(250張)會預測錯誤

在以上4類結果中,對模型訓練有顯著貢獻的是第2類和第4類的結果。由於這些預測結果是錯誤的,因此會促使模型去調整參數,進而達到訓練的目的。這兩類結果會互相拉扯,防止神經網路陷入『只輸出單一種預測的狀態』。

不幸的是,由於第2類的數量是第4類的10倍之多,因此模型會更多地學習『如何避免把貓錯認成狗』。久而久之,模型開始『取巧』,於是直接把所有圖片都預測為『貓』。

為了解決這個問題,我們要想辦法平衡這兩個類別的圖片數量。此處要注意的是,雖然原始的CIFAR-10中已經有5000張狗的圖片,但我們想要模擬的是現實中的困境,所以便假設狗的圖片從始至終都只有500張。換句話說,我們要想辦法用這500張圖片來平衡資料集。

平衡資料集的方法有很多種,由於目前的圖片資料是存放在串列(list)中,因此我們可以用extend()將僅有的500張狗的圖片重複放進資料集中。那麼,要重複多少次呢?這由不同類別圖片之間的數量比例而定。以此例來說,貓的圖片數量為狗的10倍。因此,只要重複放置10次狗的圖片至資料集中,這兩種類別的圖片數量就都會是5000張了!

成功平衡資料集後,就要重新來嘗試訓練模型了。此處依舊設定運行100次訓練迴圈,讓我們看看成果如何:

從成果可見,訓練損失呈現了下降的趨勢,而訓練準確率和驗證準確率也在緩慢的上升中。在一開始,訓練準確率只有不到60%,這只比隨機猜類別來得好一些而已。經過50次的訓練後,訓練準確率已經達到了85%。雖然訓練準確率上升的較為緩慢,但這表示模型有在嘗試學習不同類別圖片的特徵,而非直接把所有圖片都預測為貓。

讀者可以嘗試增加訓練迴圈(程式皆已整理在Colab筆記本),看看準確率能不能進一步提升。由於我們在平衡資料時,只是一味的將dog中的圖片重複放進資料集,因此模型最後可能會發生過度配適的問題,即:模型死背這500張狗的圖片中的特徵,並沒有真的學會狗的體態特徵。

為了解決這個問題,在平衡資料時,我們或許可以嘗試使用資料擴增技術(data augmentation)。舉例來說,對這500張圖片進行隨機旋轉或平移,進而產生5000張和原始圖片有些微差異的訓練樣本,再放進資料集中。這樣一來,模型就較難去完全死背這5000張不同圖片的特徵了。

本文的重點為說明資料不平衡時會造成的後果,關於資料擴增的各種技巧,此處便不多加贅述。若讀者想了解更多使用PyTorch來擴增資料的技巧,歡迎參考《核心開發者親授!PyTorch深度學習攻略》一書。

開頭提過,在肺癌偵測的任務中,資料不平衡是很常見的事。除了肺癌患者的樣本數較少外,肺癌患者的CT圖本身就是不平衡資料的例子。在這些CT圖中,只有很少一部分是惡性腫瘤組織,其餘皆屬於健康組織。如果不進行資料平衡,模型的預測結果就會被大量的健康組織所主導,而發生『只輸出單一預測』的狀況,即預測所有的組織都是健康的。這個問題在肺癌偵測中是非常致命的,需要格外重視。

《核心開發者親授!PyTorch深度學習攻略》一書包含完整的肺部偵測專案,作者會根據現實中的難題提出許多實務上的技巧,其中就包括剛剛提到的資料平衡和資料擴增。在跟著實作整個專案後,你將收穫許多寶貴的經驗。只要你對深度學習有熱忱,並且已經具備一定的Python能力,那就非常歡迎你閱讀本書!

--

--

施威銘研究室
施威銘研究室

Written by 施威銘研究室

致力開發AI領域的圖書、創客、教具,希望培養更多的AI人才。整合各種人才,投入創客產品的開發,推廣「實作學習」,希望實踐學以致用的理想。

No responses yet