Pytorch手把手實作-Generative Adversarial Network (GAN)

Tommy Huang
9 min readAug 21, 2021

--

Generative Adversarial Network(GAN)簡述

前言廢話免了,會進來看文章內容的只有四種人
1. 只想知道皮毛,GAN在幹什麼的
2. 想知道細節怎麼把GAN訓練起來
3. 收藏在我的最愛或是書籤當作有看過了
4. 上課上到一定要點點進來。

GAN屬於unsupervised learning。

白話一點,GAN是用來生成資料。

講難聽一點,GAN被廣泛用來造假的。(但也有正向的)

最近比較知名的影像轉換

AI界知名人士的小孩版本。(source: https://www.reddit.com/r/MachineLearning/comments/o843t5/d_types_of_machine_learning_papers/)

如果不認識我幫你們對應起來

我其實有找到其他人對應的圖,但我懶得放了。

下面的網址有用StyleGAN: 可以讓人變年輕微笑的範例。 https://www.reddit.com/r/MachineLearning/comments/o6wggh/r_finally_actual_real_images_editing_using/

這不是跟抖音內建功能一樣,可以換臉(卡通),可以換表情,可以自動上妝,這用到的技術就是GAN相關的,屏除到政治因素,我個人覺得抖音滿好玩的。

聽說這個螞蟻呀嘿下架了,我還沒玩到><

利用GAN技術讓老照片活起來,

Source: https://imgur.com/i284hKw

以上都是GAN應用最近比較有名的一些影片或是APP等簡單介紹。

正文開始...

GAN 生成對抗網路:顧名思義,就是有兩個網路架構,分別為「生成」(Generator)和「對抗」(Discriminator)

GAN的概念很簡單,我們可以用一部老電影來描述(中文:神鬼交鋒,英文: Catch me if you can,英文比較有感):

中文:神鬼交鋒,英文: Catch me if you can

一個造假者(李奧納多)和一個專家(湯姆漢克),造假者需要做假的東西(假支票)出來,讓專家去判斷真偽,透過專家的判斷造假者在不斷的增進自己的造假技術,直到專家無法有效的判斷真偽。

整個GAN運作的核心概念如下,李奧納多就是「生成器(Generator)」,湯姆漢克就是「對抗: 辨別器(Discriminator)」:

花樣看完了
實際上我們將GAN化成簡圖,如下

Generator (G) 和 Discriminator (D)

D要判斷「真」還是「假」

G生成的資料要呼嚨D。

從Random Vector(z,可以為均勻分布或是常態分布)丟入G生成出圖片,所以目的就是希望使得G(z)的機率分布接近D的機率分布。

GAN的核心想法

Discriminator: 希望D(x)真實資料被判給真實的機率期望值最大(接近1)

Discriminator: 希望D(G(z))假資料被判給真實的機率期望值最小(接近0)

Generator -> Discriminator: 因為要乎巄D,所以在Generator階段,希望D(G(z))假資料被判給真實的機率期望值最大(接近1)

Objective Function of GAN:

看到這邊應該很有感才對,不管是在公式或是算法上

實際上GAN的坑很多,光是Generator和Discriminator怎麼設計就是個坑了。

  1. 後面範例以DCGAN的模型要設計過Generator才有辦法Upsample到MNIST的大小(28*28)。
  2. Generator參數變化不要一次更新太大,通常可以更新幾次D後再更新G。 (MNIST範例很簡單,所以可以不用)
  3. Learning rate不要設定太大。 如果大家有看過其他人範例大部分都設定為0.0002,其實這樣的設定有文獻出處Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks

以上是很簡單的GAN理論(有錯請鞭小力一點,不要太兇)介紹。

Pytorch手把手進行DCGAN實作,以MNIST資料庫為例

這張圖的來源我忘了(應該是DCGAN的論文吧),但這些文章沒有營利,應該沒有觸法吧。

0. 先import 模組吧。

Generator

因為我的random vector(z)是採用 latents x 1 x 1 (latents代表z的維度數)

DCGAN是採用ConvTranspose2d進行上採樣的進行,也就是讓圖變大張。

MNIST圖片為28 x 28,一般上採樣通常是固定放大1倍。

1 x 1 → 上採樣 → 2 x 2 → 上採樣 → 4 x 4 → 上採樣 → 8 x 8 → 上採樣 → 16 x 16 → 上採樣 → 32 x 32

所以不會變成28 x 28。

所以我利用ConvTranspose2d的stride和pad的設計,讓上採樣可以非1倍放大,細節請看程式碼,我每一層輸出的大小有寫在備註。
1 x 1 → ConvTranspose2d → 2 x 2 → ConvTranspose2d → 3 x 3 → ConvTranspose2d → 6 x 6 → ConvTranspose2d → 7 x 7 → ConvTranspose2d → 14 x 14 → ConvTranspose2d → 28 x 28

Discriminator

這邊就沒什麼特別注意,就是建立一個分類CNN而已,所以我建立一個5層CNN+1層FC可以看下面Discriminator的定義。

我們先訂一些卷積模塊(CBR, CBLR, TCBR),然後依據上述建立「Generator」和「Discriminator」。

這邊開始我們宣告一些pytorch訓練需要的一些元件,例如:GPU的使用、「Generator」和「Discriminator」的optimizer、學習時候學習率的lr_scheduler和MNIST的dataloader等。

Generator的更新

Discriminator的更新

等下程式在執行,模型的Update (loss)需要符合上述的執行,

執行後的結果

不同epoch訓練出來生成的結果圖。

plt.plot(loss_g)
plt.plot(loss_d,'r')
plt.legend(['G','D'])
plt.show()
Generator 和Discriminator在每一次更新的loss變化。

收斂了 ,可以進行生成測試。

Generator測試

實驗一: random vector(z)是採用 latents產生範圍normal(0,1),大概範圍是-3~3之間,生成的圖片

實驗一

實驗二: random vector(z)是採用 latents產生範圍normal(0,1)*-10000,大概範圍是-30000~30000之間,生成的圖片

實驗二

實驗三: random vector(z)是採用 latents產生範圍normal(0,1)* 50000,大概範圍是-50000~50000之間,生成的圖片

實驗三

大家可以和Pytorch手把手實作-AutoEncoder這篇比較,這個random vector(z)在GAN好像比較不會影響結果,但可能是生成結構的關係,在圖片生成的過程中已經將input的random vector(z)正規化了,所以在生成的時候就不影響,但實際上是怎麼避掉這樣的影響我就沒有去深入研究。

相關內容可以看我課程的github,裡面有我寫好的ipynb可以用。

https://github.com/TommyHuang821/Pytorch_DL_Implement/blob/main/12_Pytorch_DCGAN.ipynb

--

--

Tommy Huang
Tommy Huang

Written by Tommy Huang

怕老了忘記這些吃飯的知識,開始寫文章記錄機器/深度學習相關內容。Medium現在有打賞功能(每篇文章最後面都有連結),如果覺得寫的文章不錯,也可以Donate給個Tipping吧。黃志勝 Chih-Sheng Huang (Tommy), mail: chih.sheng.huang821@gmail.com

No responses yet