How to balance the generator and the discriminator performances in a GAN?

Problem description

It's the first time I'm working with GANs and I am facing an issue regarding the Discriminator repeatedly outperforming the Generator. I am trying to reproduce the PA model from this article and I'm looking at this slightly different implementation to help me out.

I have read quite a lot of papers on how GANs work and also followed some tutorials to understand them better. Moreover, I've read articles on how to overcome the major instabilities, but I can't find a way to overcome this behavior.

In my environment, I'm using PyTorch and BCELoss(). Following the DCGAN PyTorch tutorial, I'm using the following training loop:

criterion = nn.BCELoss()
train_d = False

# --- Discriminator on real samples ---
optim_d.zero_grad()
disc_train_real = target.to(device)
batch_size = disc_train_real.size(0)
# BCELoss expects float targets: 1.0 = real
label = torch.full((batch_size,), 1., device=device)
output_d = discriminator(disc_train_real).view(-1)
loss_d_real = criterion(output_d, label)
if lossT:
    loss_d_real *= 2
# only backpropagate through D when its loss on real data is still above 0.3
if loss_d_real.item() > 0.3:
    loss_d_real.backward()
    train_d = True
D_x = output_d.mean().item()

# --- Discriminator on fake samples ---
output_g = generator(image)
# detach() so no gradients flow back into the generator here
output_d = discriminator(output_g.detach()).view(-1)
label.fill_(0)  # 0.0 = fake
loss_d_fake = criterion(output_d, label)
D_G_z1 = output_d.mean().item()
if lossT:
    loss_d_fake *= 2
loss_d = loss_d_real + loss_d_fake
if loss_d_fake.item() > 0.3:
    loss_d_fake.backward()
    train_d = True
# step the discriminator only if at least one of its losses was backpropagated
if train_d:
    optim_d.step()

# --- Generator ---
optim_g.zero_grad()  # clear stale generator gradients before the update
label.fill_(1)       # the generator wants D to classify its output as real
output_d = discriminator(output_g).view(-1)
loss_g = criterion(output_d, label)
D_G_z2 = output_d.mean().item()
if lossT:
    loss_g *= 2

loss_g.backward()
optim_g.step()

and, after an initial settling period, everything seems to work fine:

Epoch 1/5 - Step: 1900/9338  Loss G: 3.057388  Loss D: 0.214545  D(x): 0.940985  D(G(z)): 0.114064 / 0.114064
Time for the last step: 51.55 s    Epoch ETA: 01:04:13
Epoch 1/5 - Step: 2000/9338  Loss G: 2.984724  Loss D: 0.222931  D(x): 0.879338  D(G(z)): 0.159163 / 0.159163
Time for the last step: 52.68 s    Epoch ETA: 01:03:24
Epoch 1/5 - Step: 2100/9338  Loss G: 2.824713  Loss D: 0.241953  D(x): 0.905837  D(G(z)): 0.110231 / 0.110231
Time for the last step: 50.91 s    Epoch ETA: 01:02:29
Epoch 1/5 - Step: 2200/9338  Loss G: 2.807455  Loss D: 0.252808  D(x): 0.908131  D(G(z)): 0.218515 / 0.218515
Time for the last step: 51.72 s    Epoch ETA: 01:01:37
Epoch 1/5 - Step: 2300/9338  Loss G: 2.470529  Loss D: 0.569696  D(x): 0.620966  D(G(z)): 0.512615 / 0.350175
Time for the last step: 51.96 s    Epoch ETA: 01:00:46
Epoch 1/5 - Step: 2400/9338  Loss G: 2.148863  Loss D: 1.071563  D(x): 0.809529  D(G(z)): 0.114487 / 0.114487
Time for the last step: 51.59 s    Epoch ETA: 00:59:53
Epoch 1/5 - Step: 2500/9338  Loss G: 2.016863  Loss D: 0.904711  D(x): 0.621433  D(G(z)): 0.440721 / 0.435932
Time for the last step: 52.03 s    Epoch ETA: 00:59:02
Epoch 1/5 - Step: 2600/9338  Loss G: 2.495639  Loss D: 0.949308  D(x): 0.671085  D(G(z)): 0.557924 / 0.420826
Time for the last step: 52.66 s    Epoch ETA: 00:58:12
Epoch 1/5 - Step: 2700/9338  Loss G: 2.519842  Loss D: 0.798667  D(x): 0.775738  D(G(z)): 0.246357 / 0.265839
Time for the last step: 51.20 s    Epoch ETA: 00:57:19
Epoch 1/5 - Step: 2800/9338  Loss G: 2.545630  Loss D: 0.756449  D(x): 0.895455  D(G(z)): 0.403628 / 0.301851
Time for the last step: 51.88 s    Epoch ETA: 00:56:27
Epoch 1/5 - Step: 2900/9338  Loss G: 2.458109  Loss D: 0.653513  D(x): 0.820105  D(G(z)): 0.379199 / 0.103250
Time for the last step: 53.50 s    Epoch ETA: 00:55:39
Epoch 1/5 - Step: 3000/9338  Loss G: 2.030103  Loss D: 0.948208  D(x): 0.445385  D(G(z)): 0.303225 / 0.263652
Time for the last step: 51.57 s    Epoch ETA: 00:54:47
Epoch 1/5 - Step: 3100/9338  Loss G: 1.721604  Loss D: 0.949721  D(x): 0.365646  D(G(z)): 0.090072 / 0.232912
Time for the last step: 52.19 s    Epoch ETA: 00:53:55
Epoch 1/5 - Step: 3200/9338  Loss G: 1.438854  Loss D: 1.142182  D(x): 0.768163  D(G(z)): 0.321164 / 0.237878
Time for the last step: 50.79 s    Epoch ETA: 00:53:01
Epoch 1/5 - Step: 3300/9338  Loss G: 1.924418  Loss D: 0.923860  D(x): 0.729981  D(G(z)): 0.354812 / 0.318090
Time for the last step: 52.59 s    Epoch ETA: 00:52:11

that is, the gradients on the Generator are higher and start to decrease after a while, while the gradients on the Discriminator rise. As for the losses, the Generator's goes down while the Discriminator's goes up. Compared to the tutorial, I guess this can be acceptable.

Here's my first question: I've noticed that in the tutorial (usually) as D_G_z1 rises, D_G_z2 decreases (and vice versa), while in my example this happens a lot less. Is it just a coincidence or am I doing something wrong?

Given that, I've let the training procedure go on, but now I'm noticing this:

Epoch 3/5 - Step: 1100/9338  Loss G: 4.071329  Loss D: 0.031608  D(x): 0.999969  D(G(z)): 0.024329 / 0.024329
Time for the last step: 51.41 s    Epoch ETA: 01:11:24
Epoch 3/5 - Step: 1200/9338  Loss G: 3.883331  Loss D: 0.036354  D(x): 0.999993  D(G(z)): 0.043874 / 0.043874
Time for the last step: 51.63 s    Epoch ETA: 01:10:29
Epoch 3/5 - Step: 1300/9338  Loss G: 3.468963  Loss D: 0.054542  D(x): 0.999972  D(G(z)): 0.050145 / 0.050145
Time for the last step: 52.47 s    Epoch ETA: 01:09:40
Epoch 3/5 - Step: 1400/9338  Loss G: 3.504971  Loss D: 0.053683  D(x): 0.999972  D(G(z)): 0.052180 / 0.052180
Time for the last step: 50.75 s    Epoch ETA: 01:08:41
Epoch 3/5 - Step: 1500/9338  Loss G: 3.437765  Loss D: 0.056286  D(x): 0.999941  D(G(z)): 0.058839 / 0.058839
Time for the last step: 52.20 s    Epoch ETA: 01:07:50
Epoch 3/5 - Step: 1600/9338  Loss G: 3.369209  Loss D: 0.062133  D(x): 0.955688  D(G(z)): 0.058773 / 0.058773
Time for the last step: 51.05 s    Epoch ETA: 01:06:54
Epoch 3/5 - Step: 1700/9338  Loss G: 3.290109  Loss D: 0.065704  D(x): 0.999975  D(G(z)): 0.056583 / 0.056583
Time for the last step: 51.27 s    Epoch ETA: 01:06:00
Epoch 3/5 - Step: 1800/9338  Loss G: 3.286248  Loss D: 0.067969  D(x): 0.993238  D(G(z)): 0.063815 / 0.063815
Time for the last step: 52.28 s    Epoch ETA: 01:05:09
Epoch 3/5 - Step: 1900/9338  Loss G: 3.263996  Loss D: 0.065335  D(x): 0.980270  D(G(z)): 0.037717 / 0.037717
Time for the last step: 51.59 s    Epoch ETA: 01:04:16
Epoch 3/5 - Step: 2000/9338  Loss G: 3.293503  Loss D: 0.065291  D(x): 0.999873  D(G(z)): 0.070188 / 0.070188
Time for the last step: 51.85 s    Epoch ETA: 01:03:25
Epoch 3/5 - Step: 2100/9338  Loss G: 3.184164  Loss D: 0.070931  D(x): 0.999971  D(G(z)): 0.059657 / 0.059657
Time for the last step: 52.14 s    Epoch ETA: 01:02:34
Epoch 3/5 - Step: 2200/9338  Loss G: 3.116310  Loss D: 0.080597  D(x): 0.999850  D(G(z)): 0.074931 / 0.074931
Time for the last step: 51.85 s    Epoch ETA: 01:01:42
Epoch 3/5 - Step: 2300/9338  Loss G: 3.142180  Loss D: 0.073999  D(x): 0.995546  D(G(z)): 0.054752 / 0.054752
Time for the last step: 51.76 s    Epoch ETA: 01:00:50
Epoch 3/5 - Step: 2400/9338  Loss G: 3.185711  Loss D: 0.072601  D(x): 0.999992  D(G(z)): 0.076053 / 0.076053
Time for the last step: 50.53 s    Epoch ETA: 00:59:54
Epoch 3/5 - Step: 2500/9338  Loss G: 3.027437  Loss D: 0.083906  D(x): 0.997390  D(G(z)): 0.082501 / 0.082501
Time for the last step: 52.06 s    Epoch ETA: 00:59:03
Epoch 3/5 - Step: 2600/9338  Loss G: 3.052374  Loss D: 0.085030  D(x): 0.999924  D(G(z)): 0.073295 / 0.073295
Time for the last step: 52.37 s    Epoch ETA: 00:58:12

not only has D(x) increased again and become stuck at almost one, but both D_G_z1 and D_G_z2 also always show the same value. Moreover, looking at the losses it seems pretty clear that the Discriminator has outperformed the Generator. This behavior went on for the rest of the epoch and for all of the next one, until the end of training.

Hence my second question: is this normal? If not, what am I doing wrong in the procedure? How can I achieve more stable training?

I've tried to train the network using MSELoss() as suggested, and this is the output:

Epoch 1/1 - Step: 100/9338  Loss G: 0.800785  Loss D: 0.404525  D(x): 0.844653  D(G(z)): 0.030439 / 0.016316
Time for the last step: 55.22 s    Epoch ETA: 01:25:01
Epoch 1/1 - Step: 200/9338  Loss G: 1.196659  Loss D: 0.014051  D(x): 0.999970  D(G(z)): 0.006543 / 0.006500
Time for the last step: 51.41 s    Epoch ETA: 01:21:11
Epoch 1/1 - Step: 300/9338  Loss G: 1.197319  Loss D: 0.000806  D(x): 0.999431  D(G(z)): 0.004821 / 0.004724
Time for the last step: 51.79 s    Epoch ETA: 01:19:32
Epoch 1/1 - Step: 400/9338  Loss G: 1.198960  Loss D: 0.000720  D(x): 0.999612  D(G(z)): 0.000000 / 0.000000
Time for the last step: 51.47 s    Epoch ETA: 01:18:09
Epoch 1/1 - Step: 500/9338  Loss G: 1.212810  Loss D: 0.000021  D(x): 0.999938  D(G(z)): 0.000000 / 0.000000
Time for the last step: 52.18 s    Epoch ETA: 01:17:11
Epoch 1/1 - Step: 600/9338  Loss G: 1.216168  Loss D: 0.000000  D(x): 0.999945  D(G(z)): 0.000000 / 0.000000
Time for the last step: 51.24 s    Epoch ETA: 01:16:02
Epoch 1/1 - Step: 700/9338  Loss G: 1.212301  Loss D: 0.000000  D(x): 0.999970  D(G(z)): 0.000000 / 0.000000
Time for the last step: 51.61 s    Epoch ETA: 01:15:02
Epoch 1/1 - Step: 800/9338  Loss G: 1.214397  Loss D: 0.000005  D(x): 0.999973  D(G(z)): 0.000000 / 0.000000
Time for the last step: 51.58 s    Epoch ETA: 01:14:04
Epoch 1/1 - Step: 900/9338  Loss G: 1.212016  Loss D: 0.000003  D(x): 0.999932  D(G(z)): 0.000000 / 0.000000
Time for the last step: 52.20 s    Epoch ETA: 01:13:13
Epoch 1/1 - Step: 1000/9338  Loss G: 1.215162  Loss D: 0.000000  D(x): 0.999988  D(G(z)): 0.000000 / 0.000000
Time for the last step: 52.28 s    Epoch ETA: 01:12:23
Epoch 1/1 - Step: 1100/9338  Loss G: 1.216291  Loss D: 0.000000  D(x): 0.999983  D(G(z)): 0.000000 / 0.000000
Time for the last step: 51.78 s    Epoch ETA: 01:11:28
Epoch 1/1 - Step: 1200/9338  Loss G: 1.215526  Loss D: 0.000000  D(x): 0.999978  D(G(z)): 0.000000 / 0.000000
Time for the last step: 51.88 s    Epoch ETA: 01:10:35

As can be seen, the situation gets even worse. Moreover, reading the EnhanceNet paper all over again, Section 4.2.4 (Adversarial Training) states that the adversarial loss function used is BCELoss(), which is what I would expect to solve the vanishing-gradient problem I get with MSELoss().

Recommended answer

Interpreting GAN losses is a bit of a black art, because the actual loss values don't tell you much on their own.

Question 1: The frequency of swinging between discriminator/generator dominance will vary based primarily (in my experience) on a few factors: the learning rates and batch sizes, which affect the propagated loss. The particular loss metrics used will affect the variance in how the D & G networks train. The EnhanceNet paper (for the baseline) and the tutorial also use a Mean Squared Error loss - you're using a Binary Cross Entropy loss, which will change the rate at which the networks converge. I'm no expert, so here's a pretty good link to Rohan Varma's article that explains the difference between loss functions. I'd be curious to see whether your network behaves differently when you change the loss function - try it and update the question?
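
As a rough illustration of the learning-rate lever mentioned above (not something prescribed by the paper or the tutorial), a common way to keep the Discriminator from winning the race too quickly is to give it a smaller learning rate than the Generator. The values below are only example numbers and would need tuning:

import torch

# Hypothetical setting: a slower Discriminator and a faster Generator.
optim_d = torch.optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.5, 0.999))
optim_g = torch.optim.Adam(generator.parameters(), lr=4e-4, betas=(0.5, 0.999))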

Question 2: Over time both the D and G losses should settle to a value; however, it's somewhat difficult to tell whether they've converged on strong performance or whether they've converged due to something like mode collapse / diminishing gradients (see Jonathan Hui's explanation of problems in training GANs). The best way I've found is to actually inspect a cross-section of the generated images and either visually inspect the output or use some kind of perceptual metric (SSIM, PSNR, PIQ, etc.) across the generated image set.
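
For example, a minimal sketch of that kind of metric check, assuming the piq package is installed and reusing image and target from the training loop above (both already scaled to [0, 1] with shape (N, C, H, W)):

import torch
import piq  # PyTorch Image Quality package (assumed to be installed)

# Compare a batch of generated images against their targets instead of
# relying on the adversarial losses alone.
with torch.no_grad():
    fake = generator(image).clamp(0, 1)
    real = target.to(device).clamp(0, 1)
    ssim_score = piq.ssim(fake, real, data_range=1.0).item()
    psnr_score = piq.psnr(fake, real, data_range=1.0).item()
print(f"SSIM: {ssim_score:.4f}  PSNR: {psnr_score:.2f} dB")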

Some other leads that you might find useful:

This post has a couple of reasonably good pointers on interpreting GAN Losses.

Ian Goodfellow's NIPS2016 tutorial also has some solid ideas on how to balance D & G training.
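
One concrete idea from that tutorial is one-sided label smoothing: train the Discriminator against a softened "real" target (commonly 0.9) while leaving the "fake" target at 0, so that D is discouraged from becoming over-confident. A minimal sketch against the loop above, with 0.9 being the usual choice rather than anything specific to this setup:

real_label, fake_label = 0.9, 0.0  # one-sided: only the real target is softened

label = torch.full((batch_size,), real_label, device=device)
loss_d_real = criterion(discriminator(disc_train_real).view(-1), label)

label.fill_(fake_label)
loss_d_fake = criterion(discriminator(output_g.detach()).view(-1), label)

# The generator step keeps using 1.0 as its target, exactly as before.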
