機器學習看得見Lesson 3 — 使用遮擋來看模型依靠什麼特徵做預測

施威銘研究室
10 min readSep 17, 2021

--

神經網路模型常常被認為是「黑盒子」,因為使用者通常不知道模型內部到底發生什麼事情。不過,其實有越來越多方法,可以讓我們一窺模型的運作方式,今天就要來介紹一個簡單的方法:遮擋(Occlusion)

一、遮擋的原理

如果神經網路依賴某些特徵做預測,只要我們將這些特徵遮住,讓神經網路看不到,預期神經網路就沒辦法做出正確的預測了。因此,我們試圖透過遮擋,來了解神經網路是如何做預測。

舉例來說,在一個手寫數字辨識的問題中,想要辨識數字9,關鍵在哪裡?是9這個數字長長尾巴嗎?應該不是吧,因為長長的尾巴跟數字1以及7很像,所以神經網路應該是依靠偏上方的圓圈圈來做預測吧。如果我們把數字9上方遮住,預期模型大概就會辨識不出這是不是數字9了。

圖一、數字9的辨識。遮住上方,會無法辨識是9還是1還是7。遮住下方,可能還猜到是數字9。

二、準備資料、建立模型

為了驗證遮擋的效果,我們需要準備資料並且訓練一個模型。此部分是參考旗標出版的「自學機器學習 — 上Kaggle接軌世界,成為資料科學家」第4章的範例程式碼。這個範例的資料集是Kaggle平台的手寫數字資料集

import copy
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import KFold
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation
train = pd.read_csv('/kaggle/input/digit-recognizer/train.csv')
train_x = train.drop(['label'], axis=1)
train_y = train['label']
test_x = pd.read_csv('/kaggle/input/digit-recognizer/test.csv')
kf = KFold(n_splits = 4, shuffle = True, random_state = 123)
tr_idx, va_idx = list(kf.split(train_x))[0]
tr_x, va_x = train_x.iloc[tr_idx], train_x.iloc[va_idx]
tr_y, va_y = train_y.iloc[tr_idx], train_y.iloc[va_idx]
tr_x = np.array(tr_x / 255.0)
va_x = np.array(va_x / 255.0)
tr_y = to_categorical(tr_y, 10)
va_y = to_categorical(va_y, 10)

接著,我們建立一個簡單的前饋神經網路(feed-forward neural network)。輸入層配合影像,建立784個神經元。隱藏層有128個神經元,搭配sigmoid activation function。輸出層配置10個神經元,使用softmax activation function。

損失函數使用Keras內建給多元分類問題使用的categorical_crossentropy,優化器使用adam,評價指標使用accuracy。

model = Sequential()
model.add(Dense(128, input_dim = tr_x.shape[1],
activation = 'sigmoid'))
model.add(Dense(10, activation = 'softmax'))
model.compile(
loss = 'categorical_crossentropy',
optimizer = 'adam',
metrics = ['accuracy'])

接著開始訓練模型。使用5個epoch就夠了,批次大小設定為100。

result = model.fit(tr_x, tr_y,
epochs = 5,
batch_size = 100,
validation_data = (va_x, va_y),
verbose = 1)

訓練結果,驗證資料的準確率可以達到0.9444,這樣已經足夠做遮擋的展示了。

Epoch 1/5
315/315 [==============================] - 2s 4ms/step - loss: 1.1632 - accuracy: 0.7247 - val_loss: 0.3844 - val_accuracy: 0.8994
Epoch 2/5
315/315 [==============================] - 1s 3ms/step - loss: 0.3391 - accuracy: 0.9123 - val_loss: 0.2889 - val_accuracy: 0.9183
Epoch 3/5
315/315 [==============================] - 1s 2ms/step - loss: 0.2714 - accuracy: 0.9223 - val_loss: 0.2454 - val_accuracy: 0.9294
Epoch 4/5
315/315 [==============================] - 1s 2ms/step - loss: 0.2245 - accuracy: 0.9377 - val_loss: 0.2162 - val_accuracy: 0.9376
Epoch 5/5
315/315 [==============================] - 1s 2ms/step - loss: 0.2000 - accuracy: 0.9442 - val_loss: 0.1951 - val_accuracy: 0.9444

我們先把模型對訓練資料的預測結果,以及對應的預測機率都存下來,等一下會用到。

result = model.predict(tr_x)
probability = [max(row) for row in result]
predict = [row.argmax() for row in result]

三、實作遮擋

為了要知道神經網路模型是依靠手寫數字影像的哪一部分做預測,我們打算用一個5 x 5的遮罩(mask),從影像最左上角開始,一路掃描到最右下角。因此程式中有定義image_y跟image_x變數,來決定現在掃描到影像的哪一個位置。

當目前的掃描到的影像位置是x1, y1,接著就會把5 x 5的遮罩套上去,在遮罩範圍裡的像素,全部都會變成0,藉此神經網路就沒有辦法用這幾個像素做預測。為了完成這個遮罩運算,我們在程式中額外建立2個變數mask_y跟mask_x,使用image_y、image_x、mask_y、mask_x這幾個變數,就可以計算出影像中哪幾個像素要設定成0。

得到了套用遮擋的影像後,接著就餵入前一節已經訓練好的模型,並獲得預測機率。我們用「沒有遮擋的影像預測機率」,減去「有遮擋的影像預測機率」,這個差值越大,就代表因為遮擋造成模型預測準確度下降越多,也就表示模型非常依賴這部分的像素值做預測。

最後,我們把每一個位置的機率下降幅度畫出來。越亮的位置代表預測機率下降越多,也就是模型越依賴此部分的像素值做預測。

image_id = 10
mask_size = 5
diff = np.zeros((28, 28))
for image_y in range(28):
for image_x in range(28):
image_2d = copy.deepcopy(tr_x[image_id].reshape((28, 28)))
for mask_y in range(mask_size):
for mask_x in range(mask_size):
position_y = image_y + (mask_y - 1)
position_x = image_x + (mask_x - 1)
if((position_y >= 0)and(position_y < 28)and
(position_x >= 0)and(position_x < 28)):
image_2d[position_y][position_x] = 0

image_1d = image_2d.reshape((1, 784))
result = model.predict(image_1d)
result = result[0][predict[image_id]]

diff[image_y][image_x] = probability[image_id] - result
print("Predict:", predict[image_id], "with probability:",
probability[image_id])
plt.imshow(tr_x[image_id].reshape((28, 28)),
interpolation = 'nearest')
plt.show()
plt.imshow(diff, interpolation = 'nearest')
plt.show()

圖二為實作結果,上圖是原始的數字9影像,下圖是不同位置的遮擋,對預測機率的影響。可以發現模型確實比較依賴數字9上方的圓圈來做預測,所以下圖的圓圈部分比較亮,跟我們的預期是一樣的。

圖二,實作遮擋於數字9的辨識結果

透過遮擋的方法,我們就可以知道模型是依賴什麼特徵做預測,本文是一個簡單的範例,讀者可以試試不同大小的遮罩、多個遮罩,甚至可以用隨機產生遮罩個數跟大小,來看看不同遮罩的配置,對模型預測能力的影響,藉此更進一步了解模型的運作機制。

關於作者

Chia-Hao Li received the M.S. degree in computer science from Durham University, United Kingdom. He engages in computer algorithm, machine learning, and hardware/software codesign. He was former senior engineer in Mediatek, Taiwan. His currently research topic is the application of machine learning techniques for fault detection in the high-performance computing systems.

--

--

施威銘研究室
施威銘研究室

Written by 施威銘研究室

致力開發AI領域的圖書、創客、教具,希望培養更多的AI人才。整合各種人才,投入創客產品的開發,推廣「實作學習」,希望實踐學以致用的理想。

No responses yet