Pytorch手把手實作-Generative Adversarial Network (GAN)
Generative Adversarial Network(GAN)簡述
前言廢話免了,會進來看文章內容的只有四種人
1. 只想知道皮毛,GAN在幹什麼的
2. 想知道細節怎麼把GAN訓練起來
3. 收藏在我的最愛或是書籤當作有看過了
4. 上課上到一定要點點進來。
GAN屬於unsupervised learning。
白話一點,GAN是用來生成資料。
講難聽一點,GAN被廣泛用來造假的。(但也有正向的)
最近比較知名的影像轉換
如果不認識我幫你們對應起來
下面的網址有用StyleGAN: 可以讓人變年輕微笑的範例。 https://www.reddit.com/r/MachineLearning/comments/o6wggh/r_finally_actual_real_images_editing_using/
這不是跟抖音內建功能一樣,可以換臉(卡通),可以換表情,可以自動上妝,這用到的技術就是GAN相關的,屏除到政治因素,我個人覺得抖音滿好玩的。
利用GAN技術讓老照片活起來,
以上都是GAN應用最近比較有名的一些影片或是APP等簡單介紹。
正文開始...
GAN 生成對抗網路:顧名思義,就是有兩個網路架構,分別為「生成」(Generator)和「對抗」(Discriminator)
GAN的概念很簡單,我們可以用一部老電影來描述(中文:神鬼交鋒,英文: Catch me if you can,英文比較有感):
一個造假者(李奧納多)和一個專家(湯姆漢克),造假者需要做假的東西(假支票)出來,讓專家去判斷真偽,透過專家的判斷造假者在不斷的增進自己的造假技術,直到專家無法有效的判斷真偽。
整個GAN運作的核心概念如下,李奧納多就是「生成器(Generator)」,湯姆漢克就是「對抗: 辨別器(Discriminator)」:
花樣看完了
實際上我們將GAN化成簡圖,如下
Generator (G) 和 Discriminator (D)
D要判斷「真」還是「假」
G生成的資料要呼嚨D。
從Random Vector(z,可以為均勻分布或是常態分布)丟入G生成出圖片,所以目的就是希望使得G(z)的機率分布接近D的機率分布。
Discriminator: 希望D(x)真實資料被判給真實的機率期望值最大(接近1)
Discriminator: 希望D(G(z))假資料被判給真實的機率期望值最小(接近0)
Generator -> Discriminator: 因為要乎巄D,所以在Generator階段,希望D(G(z))假資料被判給真實的機率期望值最大(接近1)
Objective Function of GAN:
看到這邊應該很有感才對,不管是在公式或是算法上
實際上GAN的坑很多,光是Generator和Discriminator怎麼設計就是個坑了。
- 後面範例以DCGAN的模型要設計過Generator才有辦法Upsample到MNIST的大小(28*28)。
- Generator參數變化不要一次更新太大,通常可以更新幾次D後再更新G。 (MNIST範例很簡單,所以可以不用)
- Learning rate不要設定太大。 如果大家有看過其他人範例大部分都設定為0.0002,其實這樣的設定有文獻出處Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks
以上是很簡單的GAN理論(有錯請鞭小力一點,不要太兇)介紹。
Pytorch手把手進行DCGAN實作,以MNIST資料庫為例
0. 先import 模組吧。
Generator
因為我的random vector(z)是採用 latents x 1 x 1 (latents代表z的維度數)
DCGAN是採用ConvTranspose2d進行上採樣的進行,也就是讓圖變大張。
MNIST圖片為28 x 28,一般上採樣通常是固定放大1倍。
1 x 1 → 上採樣 → 2 x 2 → 上採樣 → 4 x 4 → 上採樣 → 8 x 8 → 上採樣 → 16 x 16 → 上採樣 → 32 x 32
所以不會變成28 x 28。
所以我利用ConvTranspose2d的stride和pad的設計,讓上採樣可以非1倍放大,細節請看程式碼,我每一層輸出的大小有寫在備註。
1 x 1 → ConvTranspose2d → 2 x 2 → ConvTranspose2d → 3 x 3 → ConvTranspose2d → 6 x 6 → ConvTranspose2d → 7 x 7 → ConvTranspose2d → 14 x 14 → ConvTranspose2d → 28 x 28
Discriminator
這邊就沒什麼特別注意,就是建立一個分類CNN而已,所以我建立一個5層CNN+1層FC可以看下面Discriminator的定義。
我們先訂一些卷積模塊(CBR, CBLR, TCBR),然後依據上述建立「Generator」和「Discriminator」。
這邊開始我們宣告一些pytorch訓練需要的一些元件,例如:GPU的使用、「Generator」和「Discriminator」的optimizer、學習時候學習率的lr_scheduler和MNIST的dataloader等。
Generator的更新
Discriminator的更新
等下程式在執行,模型的Update (loss)需要符合上述的執行,
執行後的結果
不同epoch訓練出來生成的結果圖。
plt.plot(loss_g)
plt.plot(loss_d,'r')
plt.legend(['G','D'])
plt.show()
收斂了 ,可以進行生成測試。
Generator測試
實驗一: random vector(z)是採用 latents產生範圍normal(0,1),大概範圍是-3~3之間,生成的圖片
實驗二: random vector(z)是採用 latents產生範圍normal(0,1)*-10000,大概範圍是-30000~30000之間,生成的圖片
實驗三: random vector(z)是採用 latents產生範圍normal(0,1)* 50000,大概範圍是-50000~50000之間,生成的圖片
大家可以和Pytorch手把手實作-AutoEncoder這篇比較,這個random vector(z)在GAN好像比較不會影響結果,但可能是生成結構的關係,在圖片生成的過程中已經將input的random vector(z)正規化了,所以在生成的時候就不影響,但實際上是怎麼避掉這樣的影響我就沒有去深入研究。
相關內容可以看我課程的github,裡面有我寫好的ipynb可以用。
https://github.com/TommyHuang821/Pytorch_DL_Implement/blob/main/12_Pytorch_DCGAN.ipynb