GAN论文解读
论文算法解读
原文地址https://arxiv.org/abs/1406.2661
干!GAN(Generative Adversarial Networks)从算法上来说GAN并没有特别难得网络结构,整体结构就分为两个部分一个是生成网络,一个是鉴别网络。生成网络输入一个随机噪声,根据噪声生成一张图片给鉴别器鉴别,鉴别器需要鉴别该图片是生成网络生成的还是真的来自于真实样本。这两个网络分开迭代训练。 首先初始化鉴别器和生成器的网络参数,训练鉴别器时,生成器参数静止,训练生成器时鉴别网络参数静止不变。如此交替往复,直到生成器生成的图片鉴别器鉴别其归属于生成器生成的概率是1/2。为什么是1/2?因为鉴别器对真实样本和生成的样本已经搞混淆了,认为他们是真实或者不是真实的概率都是相同的。
算法中生成器和鉴别器对应的就是一个多层感知机加入一些激活层。 算法中的对抗意思,就是生成器和鉴别器在对抗。算法中最精髓思想就是如何控制这两个网络进行对抗。而控制这一个过程的,我认为来自于其损失函数的设计。
如上图所示,鉴别器的损失函数更新方向在朝着尽可能的区分出真实样本和生成样本地方迭代,如何理解这一个过程呢,可以用极值法来试一下。 假设鉴别器足够智能完全可以区分出真实样本和生成样本呢。logD(x)恒等于0(因为D(x)恒等于1),而log(1-D(G(z)))也是恒等于0(因为D(G(z))恒等于0),最终整体的值就是0。 假设鉴别器鉴别不出真实样本和生成样本呢那么等式最终值是log(1/2)+log(1/2)是小于0的。 所以上面算法中说Updata the discriminator by ascending 网络朝着整体值的最大方向更新。 同理下面生成器的更新逻辑如下 假设鉴别器已经训练好了,现在来看log(1-D(G(z(i))))这个函数的更新特点,依然是极值法来考虑,假如鉴别器很轻松的鉴别出生成器生成的图是假的,那么log(1-D(G(z(i))))恒等于0。又假如鉴别器已经把生成的图都当做真的了,那么函数趋近于负无穷。所以生成器的网络朝着函数的极小值方向更新。Update the generator by descending。
代码解读
https://github.com/eriklindernoren/PyTorch-GAN.git 上面代码库中涵盖了很多GAN的变种,应用于不同的场合, 只解读原生的GAN网络。 https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/gan/gan.py 代码中生成鉴别器和生成器的逻辑就不赘述了,就是两个感知机,主要看两个网络是如何配合的。 这个代码中,实现者是先更新生成器的网络参数,然后再更新鉴别器,这个无所谓,只是初始时候谁先更新,不影响后面的迭代逻辑。
# Loss measures generator's ability to fool the discriminator
g_loss = adversarial_loss(discriminator(gen_imgs), valid)
g_loss.backward()
valid都是1,discriminator(gen_imgs)鉴别刚才生成的图像是0还是1,然后和假设的都是1进行交叉熵损失,直到在当前鉴别器下,生成器生成的图像经过鉴别尽可能都是1。 下面是鉴别器的损失函数设计
real_loss = adversarial_loss(discriminator(real_imgs), valid)
fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
d_loss = (real_loss + fake_loss) / 2
生成器生成的图不变了,鉴别器分别对生成图鉴别和真实图鉴别,综合考虑这两个分类的总体损失。 不断往复迭代。 效果,黄色都是生成器生成的。
我总觉得GAN和最近的AIGC扩散模型效果应该差不多,回来详细讨论下二者的差异。
文档信息
- 本文作者:Tengyu Liu
- 本文链接:https://www.liutengyu.top/2024/04/20/GAN/
- 版权声明:自由转载-非商用-非衍生-保持署名(创意共享3.0许可证)