大家都知道,自從生成式對(duì)抗網(wǎng)絡(luò)(GAN)出現(xiàn)以來(lái),便在圖像處理方面有著廣泛的應(yīng)用。但還是有很多人對(duì)于GAN不是很了解,擔(dān)心由于沒(méi)有數(shù)學(xué)知識(shí)底蘊(yùn)而學(xué)不會(huì)GAN。在本文中,谷歌研究員Stefan Hosein提供了一份初學(xué)者入門(mén)GAN的教程,在這份教程中,即使你沒(méi)有擁有深厚的數(shù)學(xué)知識(shí),你也能夠了解什么是生成式對(duì)抗網(wǎng)絡(luò)(GAN)。
類比
理解GAN的一個(gè)最為簡(jiǎn)單的方法是通過(guò)一個(gè)簡(jiǎn)單的比喻:
假設(shè)有一家商店,店主要從顧客那里購(gòu)買某些種類的葡萄酒,然后再將這些葡萄酒銷售出去。
然而,有些可惡的顧客為了賺取金錢而出售假酒。在這種情況下,店主必須能夠區(qū)分假酒和正宗的葡萄酒。
你可以想象,在最初的時(shí)候,偽造者在試圖出售假酒時(shí)可能會(huì)犯很多錯(cuò)誤,并且店主很容易就會(huì)發(fā)現(xiàn)該酒不是正宗的葡萄酒。經(jīng)歷過(guò)這些失敗之后,偽造者會(huì)繼續(xù)嘗試使用不同的技術(shù)來(lái)模擬真正的葡萄酒,而有些方法最終會(huì)取得成功?,F(xiàn)在,偽造者知道某些技術(shù)已經(jīng)能夠躲過(guò)店主的檢查,那么他就可以開(kāi)始進(jìn)一步對(duì)基于這些技術(shù)的假酒進(jìn)行改善提升。
與此同時(shí),店主可能會(huì)從其他店主或葡萄酒專家那里得到一些反饋,說(shuō)明她所擁有的一些葡萄酒并不是原裝的。這意味著店主必須改進(jìn)她的判別方式,從而確定葡萄酒是偽造的還是正宗的。偽造者的目標(biāo)是制造出與正宗葡萄酒無(wú)法區(qū)分的葡萄酒,而店主的目標(biāo)是準(zhǔn)確地分辨葡萄酒是否是正宗的。
可以這樣說(shuō),這種循環(huán)往復(fù)的競(jìng)爭(zhēng)正是GAN背后的主要思想。
生成式對(duì)抗網(wǎng)絡(luò)的組成部分
通過(guò)上面的例子,我們可以提出一個(gè)GAN的體系結(jié)構(gòu)。
GAN中有兩個(gè)主要的組成部分:生成器和鑒別器。在上面我們所描述的例子中,店主被稱為鑒別器網(wǎng)絡(luò),通常是一個(gè)卷積神經(jīng)網(wǎng)絡(luò)(因?yàn)镚AN主要用于圖像任務(wù)),主要是分配圖像是真實(shí)的概率。
偽造者被稱為生成式網(wǎng)絡(luò),并且通常也是一個(gè)卷積神經(jīng)網(wǎng)絡(luò)(具有解卷積層,deconvolution layers)。該網(wǎng)絡(luò)接收一些噪聲向量并輸出一個(gè)圖像。當(dāng)對(duì)生成式網(wǎng)絡(luò)進(jìn)行訓(xùn)練時(shí),它會(huì)學(xué)習(xí)可以對(duì)圖像的哪些區(qū)域進(jìn)行改進(jìn)/更改,以便鑒別器將難以將其生成的圖像與真實(shí)圖像區(qū)分開(kāi)來(lái)。
生成式網(wǎng)絡(luò)不斷地生成與真實(shí)圖像更為接近的圖像,而與此同時(shí),鑒別式網(wǎng)絡(luò)則試圖確定真實(shí)圖像和假圖像之間的差異。最終的目標(biāo)就是建立一個(gè)生成式網(wǎng)絡(luò),它可以生成與真實(shí)圖像無(wú)法區(qū)分的圖像。
用Keras編寫(xiě)一個(gè)簡(jiǎn)單的生成式對(duì)抗網(wǎng)絡(luò)
現(xiàn)在,你已經(jīng)了解什么是GAN,以及它們的主要組成部分,那么現(xiàn)在我們可以開(kāi)始試著編寫(xiě)一個(gè)非常簡(jiǎn)單的代碼。你可以使用Keras,如果你不熟悉這個(gè)Python庫(kù)的話,則應(yīng)在繼續(xù)進(jìn)行操作之前閱讀本教程。本教程基于易于理解的GAN進(jìn)行開(kāi)發(fā)的。
首先,你需要做的第一件事是通過(guò)pip安裝以下軟件包:
- keras
- matplotlib
- tqdm
你將使用matplotlib繪圖,tensorflow作為Keras后端庫(kù)和tqdm,以顯示每個(gè)輪數(shù)(迭代)的花式進(jìn)度條。
下一步是創(chuàng)建一個(gè)Python腳本,在這個(gè)腳本中,你首先需要導(dǎo)入你將要使用的所有模塊和函數(shù)。在使用它們時(shí)將給出每個(gè)解釋。
importos
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from keras.layers import Input
from keras.models import Model, Sequential
from keras.layers.core import Dense, Dropout
from keras.layers.advanced_activations import LeakyReLU
from keras.datasets import mnist
from keras.optimizers import Adam
from keras import initializers
你現(xiàn)在需要設(shè)置一些變量:
# Let Keras know that we are using tensorflow as our backend engine
os.environ["KERAS_BACKEND"] = "tensorflow"
# To make sure that we can reproduce the experiment and get the same results
np.random.seed(10)
# The dimension of our random noise vector.
random_dim = 100
在開(kāi)始構(gòu)建鑒別器和生成器之前,你首先應(yīng)該收集數(shù)據(jù),并對(duì)其進(jìn)行預(yù)處理。你將會(huì)使用到常見(jiàn)的MNIST數(shù)據(jù)集,該數(shù)據(jù)集具有一組從0到9的單個(gè)數(shù)字圖像。
MINST數(shù)字樣本
def load_minst_data():
# load the data
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# normalize our inputs to be in the range[-1, 1]
x_train = (x_train.astype(np.float32) - 127.5)/127.5
# convert x_train with a shape of (60000, 28, 28) to (60000, 784) so we have
# 784 columns per row
x_train = x_train.reshape(60000, 784)
return (x_train, y_train, x_test, y_test)
需要注意的是,mnist.load_data()是Keras的一部分,這使得你可以輕松地將MNIST數(shù)據(jù)集導(dǎo)入至工作區(qū)域中。
現(xiàn)在,你可以開(kāi)始創(chuàng)建你的生成器和鑒別器網(wǎng)絡(luò)了。在這一過(guò)程中,你會(huì)使用到Adam優(yōu)化器。此外,你還需要?jiǎng)?chuàng)建一個(gè)帶有三個(gè)隱藏層的神經(jīng)網(wǎng)絡(luò),其激活函數(shù)為L(zhǎng)eaky Relu。對(duì)于鑒別器而言,你需要為其添加dropout層(dropout layers),以提高對(duì)未知圖像的魯棒性。
def get_optimizer():
return Adam(lr=0.0002, beta_1=0.5)
def get_generator(optimizer):
generator = Sequential()
generator.add(Dense(256, input_dim=random_dim, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
generator.add(LeakyReLU(0.2))
generator.add(Dense(512))
generator.add(LeakyReLU(0.2))
generator.add(Dense(1024))
generator.add(LeakyReLU(0.2))
generator.add(Dense(784, activation='tanh'))
generator.compile(loss='binary_crossentropy', optimizer=optimizer)
return generator
def get_discriminator(optimizer):
discriminator = Sequential()
discriminator.add(Dense(1024, input_dim=784, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(512))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(256))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(1, activation='sigmoid'))
discriminator.compile(loss='binary_crossentropy', optimizer=optimizer)
return discriminator
接下來(lái),則需要將發(fā)生器和鑒別器組合在一起!
def get_gan_network(discriminator, random_dim, generator, optimizer):
# We initially set trainable to False since we only want to train either the
# generator or discriminator at a time
discriminator.trainable = False
# gan input (noise) will be 100-dimensional vectors
gan_input = Input(shape=(random_dim,))
# the output of the generator (an image)
x = generator(gan_input)
# get the output of the discriminator (probability if the image is real or not)
gan_output = discriminator(x)
gan = Model(inputs=gan_input, outputs=gan_output)
gan.compile(loss='binary_crossentropy', optimizer=optimizer)
return gan
為了完整起見(jiàn),你還可以創(chuàng)建一個(gè)函數(shù),使其每訓(xùn)練20個(gè)輪數(shù)就對(duì)生成的圖像進(jìn)行1次保存。由于這不是本次課程的核心內(nèi)容,因此你不必完全理解該函數(shù)。
def plot_generated_images(epoch, generator, examples=100, dim=(10, 10), figsize=(10, 10)):
noise = np.random.normal(0, 1, size=[examples, random_dim])
generated_images = generator.predict(noise)
generated_images = generated_images.reshape(examples, 28, 28)
plt.figure(figsize=figsize)
for i in range(generated_images.shape[0]):
plt.subplot(dim[0], dim[1], i+1)
plt.imshow(generated_images[i], interpolation='nearest', cmap='gray_r')
plt.axis('off')
plt.tight_layout()
plt.savefig('gan_generated_image_epoch_%d.png' % epoch)
你現(xiàn)在已經(jīng)編碼了大部分網(wǎng)絡(luò),剩下的就是訓(xùn)練這個(gè)網(wǎng)絡(luò),并查看你創(chuàng)建的圖像。
def train(epochs=1, batch_size=128):
# Get the training and testing data
x_train, y_train, x_test, y_test = load_minst_data()
# Split the training data into batches of size 128
batch_count = x_train.shape[0] / batch_size
# Build our GAN netowrk
adam = get_optimizer()
generator = get_generator(adam)
discriminator = get_discriminator(adam)
gan = get_gan_network(discriminator, random_dim, generator, adam)
for e in xrange(1, epochs+1):
print '-'*15, 'Epoch %d' % e, '-'*15
for _ in tqdm(xrange(batch_count)):
# Get a random set of input noise and images
noise = np.random.normal(0, 1, size=[batch_size, random_dim])
image_batch = x_train[np.random.randint(0, x_train.shape[0], size=batch_size)]
# Generate fake MNIST images
generated_images = generator.predict(noise)
X = np.concatenate([image_batch, generated_images])
# Labels for generated and real data
y_dis = np.zeros(2*batch_size)
# One-sided label smoothing
y_dis[:batch_size] = 0.9
# Train discriminator
discriminator.trainable = True
discriminator.train_on_batch(X, y_dis)
# Train generator
noise = np.random.normal(0, 1, size=[batch_size, random_dim])
y_gen = np.ones(batch_size)
discriminator.trainable = False
gan.train_on_batch(noise, y_gen)
if e == 1 or e % 20 == 0:
plot_generated_images(e, generator)
if __name__ == '__main__':
train(400, 128)
在訓(xùn)練400個(gè)輪數(shù)后,你可以查看生成的圖像。在查看經(jīng)過(guò)1個(gè)輪數(shù)訓(xùn)練后而生成的圖像時(shí),你會(huì)發(fā)現(xiàn)它沒(méi)有任何真實(shí)結(jié)構(gòu),在查看經(jīng)過(guò)40個(gè)輪數(shù)訓(xùn)練后而生成的圖像時(shí),你會(huì)發(fā)現(xiàn)數(shù)字開(kāi)始成形,最后,在查看經(jīng)過(guò)400個(gè)輪數(shù)訓(xùn)練后而生成的圖像時(shí),你會(huì)發(fā)現(xiàn),除了一組數(shù)字難以辨識(shí)外,其余大多數(shù)數(shù)字都清晰可見(jiàn)。
訓(xùn)練1個(gè)輪數(shù)后的結(jié)果(上)| 訓(xùn)練40個(gè)輪數(shù)后的結(jié)果(中) | 訓(xùn)練400個(gè)輪數(shù)后的結(jié)果(下)
此代碼在CPU上運(yùn)行一次大約需要2分鐘,這也是我們選擇該代碼的主要原因。你可以嘗試進(jìn)行更多輪數(shù)的訓(xùn)練,并向生成器和鑒別器中添加更多數(shù)量(種類)的層。當(dāng)然,在僅使用CPU的前提下,采用更復(fù)雜和更深層的體系結(jié)構(gòu)時(shí),相應(yīng)的代碼運(yùn)行時(shí)間也會(huì)有所延長(zhǎng)。但也不要因此放棄嘗試。
至此,你已經(jīng)完成了全部的學(xué)習(xí)任務(wù),你以一種直觀的方式學(xué)習(xí)了生成式對(duì)抗網(wǎng)絡(luò)(GAN)的基礎(chǔ)知識(shí)!并且,你還在Keras庫(kù)的協(xié)助下實(shí)現(xiàn)了你的第一個(gè)模型。
-
GaN
+關(guān)注
關(guān)注
19文章
1909瀏覽量
72685
原文標(biāo)題:無(wú)需數(shù)學(xué)背景!谷歌研究員為你解密生成式對(duì)抗網(wǎng)絡(luò)
文章出處:【微信號(hào):AItists,微信公眾號(hào):人工智能學(xué)家】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
相關(guān)推薦
評(píng)論