0
| 本文作者: AI研習(xí)社 | 2017-08-08 17:35 |
在我們之前的文章中,我們學(xué)習(xí)了如何構(gòu)造一個簡單的 GAN 來生成 MNIST 手寫圖片。對于圖像問題,卷積神經(jīng)網(wǎng)絡(luò)相比于簡單地全連接的神經(jīng)網(wǎng)絡(luò)更具優(yōu)勢,因此,我們這一節(jié)我們將繼續(xù)深入 GAN,通過融合卷積神經(jīng)網(wǎng)絡(luò)來對我們的 GAN 進(jìn)行改進(jìn),實(shí)現(xiàn)一個深度卷積 GAN。如果還沒有親手實(shí)踐過 GAN 的小伙伴可以先去學(xué)習(xí)一下上一篇專欄:生成對抗網(wǎng)絡(luò)(GAN)之 MNIST 數(shù)據(jù)生成。
專欄中的所有代碼都在我的 GitHub中,歡迎 star 與 fork。
本次代碼在 NELSONZHAO/zhihu/dcgan,里面包含了兩個文件:
dcgan_mnist:基于 MNIST 手寫數(shù)據(jù)集構(gòu)造深度卷積 GAN 模型
dcgan_cifar:基于 CIFAR 數(shù)據(jù)集構(gòu)造深度卷積 GAN 模型
本文主要以 MNIST 為例進(jìn)行介紹,兩者在本質(zhì)上沒有差別,只在細(xì)微的參數(shù)上有所調(diào)整。由于窮學(xué)生資源有限,沒有對模型增加迭代次數(shù),也沒有構(gòu)造更深的模型。并且也沒有選取像素很高的圖像,高像素非常消耗計(jì)算量。本節(jié)只是一個拋磚引玉的作用,讓大家了解 DCGAN 的結(jié)構(gòu),如果有資源的小伙伴可以自己去嘗試其他更清晰的圖片以及更深的結(jié)構(gòu),相信會取得很不錯的結(jié)果。
Python3
TensorFlow 1.0
Jupyter notebook
整個正文部分將包括以下部分:
- 數(shù)據(jù)加載
- 模型輸入
- Generator
- Discriminator
- Loss
- Optimizer
- 訓(xùn)練模型
- 可視化
數(shù)據(jù)加載
數(shù)據(jù)加載部分采用 TensorFlow 中的 input_data 接口來進(jìn)行加載。關(guān)于加載細(xì)節(jié)在前面的文章中已經(jīng)寫了很多次啦,相信看過我文章的小伙伴對 MNIST 加載也非常熟悉,這里不再贅述。
模型輸入
在 GAN 中,我們的輸入包括兩部分,一個是真實(shí)圖片,它將直接輸入給 discriminator 來獲得一個判別結(jié)果;另一個是隨機(jī)噪聲,隨機(jī)噪聲將作為 generator 來生成圖片的材料,generator 再將生成圖片傳遞給 discriminator 獲得一個判別結(jié)果。

上面的函數(shù)定義了輸入圖片與噪聲圖片兩個 tensor。
Generator
生成器接收一個噪聲信號,基于該信號生成一個圖片輸入給判別器。在上一篇專欄文章生成對抗網(wǎng)絡(luò)(GAN)之 MNIST 數(shù)據(jù)生成中,我們的生成器是一個全連接層的神經(jīng)網(wǎng)絡(luò),而本節(jié)我們將生成器改造為包含卷積結(jié)構(gòu)的網(wǎng)絡(luò),使其更加適合處理圖片輸入。整個生成器結(jié)構(gòu)如下:

我們采用了 transposed convolution 將我們的噪聲圖片轉(zhuǎn)換為了一個與輸入圖片具有相同 shape 的生成圖像。我們來看一下具體的實(shí)現(xiàn)代碼:

上面的代碼是整個生成器的實(shí)現(xiàn)細(xì)節(jié),里面包含了一些 trick,我們來一步步地看一下。
首先我們通過一個全連接層將輸入的噪聲圖像轉(zhuǎn)換成了一個 1 x 4*4*512 的結(jié)構(gòu),再將其 reshape 成一個 [batch_size, 4, 4, 512] 的形狀,至此我們其實(shí)完成了第一步的轉(zhuǎn)換。接下來我們使用了一個對加速收斂及提高卷積神經(jīng)網(wǎng)絡(luò)性能中非常有效的方法——加入 BN(batch normalization),它的思想是歸一化當(dāng)前層輸入,使它們的均值為 0 和方差為 1,類似于我們歸一化網(wǎng)絡(luò)輸入的方法。它的好處在于可以加速收斂,并且加入 BN 的卷積神經(jīng)網(wǎng)絡(luò)受權(quán)重初始化影響非常小,具有非常好的穩(wěn)定性,對于提升卷積性能有很好的效果。關(guān)于 batch normalization,我會在后面專欄中進(jìn)行一個詳細(xì)的介紹。
完成 BN 后,我們使用 Leaky ReLU 作為激活函數(shù),在上一篇專欄中我們已經(jīng)提過這個函數(shù),這里不再贅述。最后加入 dropout 正則化。剩下的 transposed convolution 結(jié)構(gòu)層與之類似,只不過在最后一層中,我們不采用 BN,直接采用 tanh 激活函數(shù)輸出生成的圖片。
在上面的 transposed convolution 中,很多小伙伴肯定會對每一層 size 的變化疑惑,在這里來講一下在 TensorFlow 中如何來計(jì)算每一層 feature map 的 size。首先,在卷積神經(jīng)網(wǎng)絡(luò)中,假如我們使用一個 k x k 的 filter 對 m x m x d 的圖片進(jìn)行卷積操作,strides 為 s,在 TensorFlow 中,當(dāng)我們設(shè)置 padding='same'時,卷積以后的每一個 feature map 的 height 和 width 為;當(dāng)設(shè)置 padding='valid'時,每一個 feature map 的 height 和 width 為
。那么反過來,如果我們想要進(jìn)行 transposed convolution 操作,比如將 7 x 7 的形狀變?yōu)?14 x 14,那么此時,我們可以設(shè)置 padding='same',strides=2 即可,與 filter 的 size 沒有關(guān)系;而如果將 4 x 4 變?yōu)?7 x 7 的話,當(dāng)設(shè)置 padding='valid'時,即
,此時 s=1,k=4 即可實(shí)現(xiàn)我們的目標(biāo)。
上面的代碼中我也標(biāo)注了每一步 shape 的變化。
Discriminator
Discriminator 接收一個圖片,輸出一個判別結(jié)果(概率)。其實(shí) Discriminator 完全可以看做一個包含卷積神經(jīng)網(wǎng)絡(luò)的圖片二分類器。結(jié)構(gòu)如下:

實(shí)現(xiàn)代碼如下:

上面代碼其實(shí)就是一個簡單的卷積神經(jīng)網(wǎng)絡(luò)圖像識別問題,最終返回 logits(用來計(jì)算 loss)與 outputs。這里沒有加入池化層的原因在于圖片本身經(jīng)過多層卷積以后已經(jīng)非常小了,并且我們加入了 batch normalization 加速了訓(xùn)練,并不需要通過 max pooling 來進(jìn)行特征提取加速訓(xùn)練。
Loss Function

Loss 部分分別計(jì)算 Generator 的 loss 與 Discriminator 的 loss,和之前一樣,我們加入 label smoothing 防止過擬合,增強(qiáng)泛化能力。
Optimizer
GAN 中實(shí)際包含了兩個神經(jīng)網(wǎng)絡(luò),因此對于這兩個神經(jīng)網(wǎng)絡(luò)要分開進(jìn)行優(yōu)化。代碼如下:

這里的 Optimizer 和我們之前不同,由于我們使用了 TensorFlow 中的 batch normalization 函數(shù),這個函數(shù)中有很多 trick 要注意。首先我們要知道,batch normalization 在訓(xùn)練階段與非訓(xùn)練階段的計(jì)算方式是有差別的,這也是為什么我們在使用 batch normalization 過程中需要指定 training 這個參數(shù)。上面使用 tf.control_dependencies 是為了保證在訓(xùn)練階段能夠一直更新 moving averages。具體參考 A Gentle Guide to Using Batch Normalization in Tensorflow - Rui Shu。
訓(xùn)練
到此為止,我們就完成了深度卷積 GAN 的構(gòu)造,接著我們可以對我們的 GAN 來進(jìn)行訓(xùn)練,并且定義一些輔助函數(shù)來可視化迭代的結(jié)果。代碼太長就不放上來了,可以直接去我的 GitHub 下載。
我這里只設(shè)置了 5 輪 epochs,每隔 100 個 batch 打印一次結(jié)果,每一行代表同一個 epoch 下的 25 張圖:

我們可以看出僅僅經(jīng)過了少部分的迭代就已經(jīng)生成非常清晰的手寫數(shù)字,并且訓(xùn)練速度是非常快的。

上面的圖是最后幾次迭代的結(jié)果。我們可以回顧一下上一篇的一個簡單的全連接層的 GAN,收斂速度明顯不如深度卷積 GAN。
到此為止,我們學(xué)習(xí)了一個深度卷積 GAN,并且看到相比于之前簡單的 GAN 來說,深度卷積 GAN 的性能更加優(yōu)秀。當(dāng)然除了 MNST 數(shù)據(jù)集以外,小伙伴兒們還可以嘗試很多其他圖片,比如我們之前用到過的 CIFAR 數(shù)據(jù)集,我在這里也實(shí)現(xiàn)了一個 CIFAR 數(shù)據(jù)集的圖片生成,我只選取了馬的圖片進(jìn)行訓(xùn)練:
剛開始訓(xùn)練時:

訓(xùn)練 50 個 epochs:

這里我只設(shè)置了 50 次迭代,可以看到最后已經(jīng)生成了非常明顯的馬的圖像,可見深度卷積 GAN 的優(yōu)勢。
我的 GitHub:NELSONZHAO (Nelson Zhao)
上面包含了我的專欄中所有的代碼實(shí)現(xiàn),歡迎 star,歡迎 fork。
雷峰網(wǎng)版權(quán)文章,未經(jīng)授權(quán)禁止轉(zhuǎn)載。詳情見轉(zhuǎn)載須知。