從理論到實例,看 GAN 是怎麼訓練出來的
本文取材自旗標科技出版的「GAN 對抗式生成網路」一書,圖文版權均歸該書所有。
GAN(Generative Adversarial Network)是 2014 年由當時還在蒙特利爾大學攻讀博士的 Ian Goodfellow 所發明,中文譯名為對抗式生成網路,代表它是一種對抗式的生成網路,不過也有人直譯為生成對抗網路。
這種技術讓電腦結合兩組不同的神經網路來生成擬真資料,而不是只用一組。GAN 的傑出表現與多功能性令人刮目相看,像是生成幾可亂真的假圖片、將隨意塗鴉轉換成擬真的影像、把在影片中奔跑的馬改成斑馬等等。這些都不需要使用大量精心標記過的訓練資料就可以輕易達成,由此可見GAN 的威力有多強大。
GAN 能生成多驚人的資料,看看真實的樣本就知道了,例如下圖中的人像合成。2014 年GAN 剛誕生時,機器頂多只能產生一張模糊的面孔,這在當時已被認為是突破性的成功了。到了2017 年,電腦合成出的虛擬人像,已經和高解析度的真人照片無異,GAN 只花了三年就辦到了。
圖片來源: “The Malicious Use of Artificial Intelligence: Forecasting, Pre vention, and Mitigation,” by Miles Brundage et al., 2018, https://arxiv.org/abs/1802.07228、及 “A Style-Based Generator Architecture for Generative Adversarial Networks,” by Tero Karras, Samuli Laine, Timo Aila, 2019, https://arxiv.org/abs/1812.04948。
GAN 長什麼樣子?
GAN 是由兩組神經網路模型所組成:生成器(Generator)被訓練用來生成假資料,鑑別器(Discriminator)則被訓練如何辨別資料真偽。
Generative Adversarial Network 中的 「Generative(生成)」一詞,指出了模型的終極目標:生成新資料。GAN 能學會生成怎樣的資料,取決於採用的訓練集。例如,若想讓 GAN 合成出有達文西風格的影像,就得在訓練集裡加入大量達文西的作品。
「Adversarial(對抗)」一詞,是指 GAN 內部的生成器與鑑別器兩模型之間,會不斷相互競爭:
- 生成器的目標,是創造出與訓練集內真樣本很像的假樣本。例如:偽造出神似達文西的畫作。若從實務面來看,則是要讓鑑別器將假樣本誤判為真樣本。
- 鑑別器的目標,則是要能判定真假,也就是要能正確分辨出生成器的假樣本與訓練集的真樣本。鑑別器就相當於藝術鑑賞家,負責鑑定達文西畫作的真假。
這兩組神經網路會不斷鬥智以求擊敗對方:生成器偽造的手段愈高明,鑑別器判斷真假的眼光就要越犀利。反之亦然,因此不斷對抗的結果,生成器所生成的假樣本就會越來越像真樣本,最後可能連我們人類也難以分辨真假。
GAN 的程式如何運作?
再以程式設計的角度來看,生成器會盡可能生成和真品非常相似的假樣本。一開始訓練時,可能只會生出亂七八糟的東西,但它會借由鑑別器的回饋(是否被認為是真樣本),一步步學習如何生成更逼真的假樣本。
鑑別器則負責分辨某樣本是真(來自訓練集)或假(來自生成器),然後將結果回饋給生成器。因此當鑑別器把假樣本當成真樣本時,生成器就知道自己進步了。反之,當鑑別器識破生成器的假樣本時,生成器也會知道自己必須再改進(調整自己神經網路中的參數)。
鑑別器本身也會不斷進步,就跟普通的分類器一樣,藉著比較其預測結果(真或假)與實際答案(真或假)的差異來學習。因此,當生成器產出的資料越來越逼真時,鑑別器分辨真偽的能力也會跟著越高明,兩組神經網路會不斷競爭(對抗),因此都會持續進步。下表整理了 GAN 兩組網路的重點:
GAN 的訓練過程
我們可將 GAN 的基本訓練過程,以下面的演算法來表示:(G 代表生成器、D 代表鑑別器)
For 每個訓練迭代 do 步驟 1:訓練鑑別器
a. 隨機取一小批真樣本:x
b. 取一小批隨機雜訊向量 z 生成假樣本:G(z)=x*
c. 計算 D(x) 與 D(x*) 的分類損失,再用反向傳播法,
根據總誤差調整參數 θ(D),以將分類損失最小化。 步驟 2:訓練生成器
a. 取一小批隨機雜訊向量 z 生成假樣本:G(z)=x*
b. 只計算 D(x*) 的分類損失,再用反向傳播法,
根據總誤差調整參數 θ(G),以將分類損失最大化。
注意!這裡只會調整 θ(G) 而不會調整 θ(D),
θ(D) 只有在步驟 1 才做調整。End for
在訓練鑑別器 (步驟1)時,生成器的參數會保持不變;同樣的,在訓練生成器(步驟2)時,鑑別器的參數也會保持不變。之所以要限制 D 和 G 在訓練時只能調整自身的參數,是為了隔離該網路之外其他參數所造成的影響,如此可確保 D 和 G 都能得到明確的回饋訊號,進而做出相應的調整,不會因另一方的調整而受影響。我們可把它想像成兩名參賽者輪流出招。
GAN 的詳細運作流程圖
假設我們要訓練 GAN 來生成幾可亂真的 MNIST 手寫數字圖片,那麼可用下面的 GAN 架構:
GAN 的訓練分成兩個步驟:訓練鑑別器和訓練生成器,整個訓練過程就是不斷重複這兩個步驟。前面演算法中的 a、b、c 子步驟也分別標示在圖中了,請自行對照觀看。
實作 GAN 來生成 MNIST 圖片
這裡我們使用 Keras 來撰寫所需的 GAN 模型,訓練資料集也是直接用 Keras 所內建的 MNIST 資料集。
不過為求簡單易懂,我們使用 2 個 Dense 層來建構生成器和鑑別器,因此最後生成的圖片並沒有很完美。程式在訓練過程中會不斷將生成的圖片顯示出來,因此我們可以看到生成器隨著訓練而逐漸進化的過程:最初只是隨機雜訊,後來越來越有手寫數字的樣子。
限於篇幅無法詳細列出相關程式碼,有興趣的讀者可直接連到我們在 Colab 上的筆記本觀看程式碼,並實際執行看看結果如何。另外也可再多加一些神經層、或做各種不同的測試。若對程式碼有不了解的地方,也可參閱「GAN 對抗式生成網路」一書的第 3 章,那裡有非常詳細而完整的解說。
當然,這只是最基本款的 GAN 模型,其他還有更多功能強大的 GAN 模型,未來有機會我們再撰文和大家分享。感謝大家的耐心閱讀,也歡迎大家一起加入 GAN 的學習行列。