想必很多人知道在二元分類的應用裡,Cross Entropy Loss是相當常用的損失函數,然而在人像偵測這樣二元分類的應用中,使用Cross Entropy Loss有時並沒有辦法訓練出好的模型,今天就一起來好好探討Cross Entropy Loss在物件偵測中可能的問題吧。
一、Cross Entropy
假設我們的二元分類問題,標籤分別為0跟1,並且我們的模型輸出為「某筆資料屬於標籤1的機率是多少」,則我們可以使用Cross Entropy Loss作為訓練模型時的損失函數:
y為某一筆資料的標籤,p為此資料為標籤1的機率。我們可以轉換模型的輸出為「某筆資料屬於標籤的機率是多少」,進一步化簡損失函數:
當我們算出每一筆資料的Cross Entropy Loss之後,可以全部加起來取平均,即可得到所有資料的Cross Entropy Loss。
二、Balanced Cross Entropy
上述的Cross Entropy Loss看起來很簡單、很好用,但是當我們面對不平衡的資料集,比如標籤為0的資料有10000筆,標籤為1的資料有10筆,這樣模型可能直接把所有資料分類到標籤0就好了,顯然這樣的結果並不是我們期待。
Balanced Cross Entropy即是要解決這個問題!計算每一筆資料的Cross Entropy Loss之後,再乘上一個Weighting Factor,來調整每一個Cross Entropy Loss的重要性:
以上述的範例,我們可以設定標籤0的資料的Weighting Factor為10,標籤1的資料的Weighting Factor為10000。算出每一筆資料的Balanced Cross Entropy Loss之後,再做加權平均,即可得到所有資料的Balanced Cross Entropy Loss。
關於不平衡資料集的處理,除了透過調整損失函數之外,還有一些其他方法,比如Sampling,詳細技術可以參考旗標出版的「Kaggle競賽攻頂秘笈 — 揭開Grandmaster的特徵工程心法,掌握制勝的關鍵技術」。
三、Focal Loss
看起來Balanced Cross Entropy Loss已經很完美了呀!但是我們接下來要介紹的情況,可能還有調整的空間。
假設今天屬於標籤0的數量,比標籤1的數量多非常多,而且判斷某個物件是標籤0是非常非常容易,這種問題我們稱為Easy/Hard Examples。也就是說,這類問題的特徵是「不平衡資料集,而且量大的類別還超級好辨識」。實際應用中,像是YOLO物件偵測演算法就是屬於這類問題的其中一個:大部分的影像,物件(標籤為1)的量非常非常少,而且辨識背景實在太容易了。在這種情況下,可能還是會因為有大量簡單且好辨識的背景,使得模型訓練的時候使用Balanced Cross Entropy Loss,可能也無法針對稀少且難以辨識的物件做有效率的優化。接下來我們就要來探討一個研究團隊提出來的解決方案:Focal Loss (Lin et al., 2018)。
先來看看Focal Loss長什麼樣子:
此研究團隊稱公式中的alpha為Weighting Factor,gamma為Modulating Factor。這數學看起來…有點奇怪,接下來就來好好解析一下這個損失函數的特性。首先,如果Modulating Factor是0,Focal Loss就會變成Imbalanced Cross Entropy Loss;如果Modulating Factor是0且Weighting Factor是1,Focal Loss就會變成Cross Entropy Loss。所以讀者可以想像Focal Loss是更通用的損失函數。
接下來,我們代入一些值來觀察Focal Loss計算結果(表一)。假設Modulating Factor是0且Weighting Factor是1,模型預測一筆標籤為1的資料之機率為0.9,其Cross Entropy Loss大約為0.0458;模型預測一筆標籤為1的資料之機率為0.1,其Cross Entropy Loss為1。假設Modulating Factor是2且Weighting Factor是1,模型預測一筆標籤為1的資料之機率為0.9,其Cross Entropy Loss大約為0.000458;模型預測一筆標籤為1的資料之機率為0.1,其Cross Entropy Loss為0.81。大家可以會發現,對於超級好分類的資料,就算預測機率不是這麼完美,即使資料量超級龐大,累積起來也不會對損失函數影響太多。反而那些很難預測的資料,錯一筆就會對損失函數貢獻非常多。
接下來,我們針對不同的Modulating Factor,畫出Focal Loss跟預測一筆標籤為1的機率之間關係,我們都設定Weighting Factor為1。可以發現當Modulating Factor數值越大,超級好分類的資料對於Focal Loss的貢獻就會急速衰減。
四、模擬
我們做一個簡單的模擬,2種資料集以及3種分類結果的情況下,Cross Entropy Loss、Balanced Cross Entropy Loss、以及Focal Loss計算結果。中心機率代表該類別資料的平均預測機率,Balanced Cross Entropy Loss跟Focal Loss都以各類別資料的反比作為Weighting Factor,Focal Loss的Modulating Factor為2。模擬結果如表二,紅色數字的結果呈現「在不平等分類的狀況下,Balanced Cross Entropy比較能反應出模型訓練結果」;綠色數字的結果呈現「對於錯得很離譜的少數類別資料,Focal Loss可以給以更大的懲罰」。
重點整理
1、如果資料集有出現不平衡資料集,可以考慮在損失函數加上Weighting Factor
2、如果資料集有出現Easy/Hard Examples,可以考慮在損失函數加上Modulating Factor
3、Focal Loss可以同時解決不平衡資料集、以及Easy/Hard Examples
參考資料
Lin, T. Y., Goyal, P., Girshick, R., He, K., and Dollar, P. (2018). Focal Loss for Dense Object Detection. IEEE Transactions on Pattern Analysis and Machine Intelligence, 42(2), pp. 318–327.
關於作者
Chia-Hao Li received the M.S. degree in computer science from Durham University, United Kingdom. He engages in computer algorithm, machine learning, and hardware/software codesign. He was former senior engineer in Mediatek, Taiwan. His currently research topic is the application of machine learning techniques for fault detection in the high-performance computing systems.