Les GAN tentent de reproduire une distribution de probabilité. Ils doivent donc utiliser des fonctions de perte qui reflètent la distance entre la distribution des données générées par le GAN et la distribution des données réelles.
Comment capturez-vous la différence entre deux distributions dans les fonctions de perte GAN ? Cette question est un domaine de recherche actif, et de nombreuses approches ont été proposées. Nous allons ici aborder deux fonctions de perte GAN courantes, toutes deux implémentées dans TF-GAN:
- perte minimax: fonction de perte utilisée dans l'article qui a présenté les GAN.
- Perte de Wasserstein: fonction de perte par défaut pour les estimateurs TF-GAN. Décrit pour la première fois dans un article de 2017.
TF-GAN implémente également de nombreuses autres fonctions de perte.
Une ou deux fonctions de perte ?
Un GAN peut avoir deux fonctions de perte: une pour l'entraînement du générateur et une pour l'entraînement du discriminateur. Comment deux fonctions de perte peuvent-elles fonctionner ensemble pour refléter une mesure de distance entre des distributions de probabilité ?
Dans les schémas de perte que nous allons examiner ici, les pertes du générateur et du discriminateur dérivent d'une seule mesure de la distance entre les distributions de probabilités. Toutefois, dans ces deux schémas, le générateur ne peut affecter qu'un seul terme de la mesure de distance: celui qui reflète la distribution des données factices. Par conséquent, lors de l'entraînement du générateur, nous supprimons l'autre terme, qui reflète la distribution des données réelles.
Les pertes du générateur et du discriminateur semblent différentes à la fin, même si elles dérivent d'une seule formule.
Perte minimax
Dans l'article qui a présenté les GAN, le générateur tente de minimiser la fonction suivante, tandis que le discriminateur tente de la maximiser:
Dans cette fonction:
D(x)
correspond à l'estimation du discriminateur de la probabilité que l'instance de données réelle x soit réelle.- Ex est la valeur attendue pour toutes les instances de données réelles.
G(z)
est la sortie du générateur lorsqu'il reçoit le bruit z.D(G(z))
est l'estimation du discriminateur de la probabilité qu'une instance factice soit réelle.- Ez est la valeur attendue pour toutes les entrées aléatoires du générateur (en fait, la valeur attendue pour toutes les instances factices générées G(z)).
- La formule découle de l'entropie croisée entre les distributions réelles et générées.
Le générateur ne peut pas affecter directement le terme log(D(x))
dans la fonction. Par conséquent, pour le générateur, minimiser la perte équivaut à minimiser log(1 -
D(G(z)))
.
Dans TF-GAN, consultez minimax_discriminator_loss et minimax_generator_loss pour une implémentation de cette fonction de perte.
Perte minimax modifiée
L'article original sur les GAN indique que la fonction de perte minimax ci-dessus peut entraîner un blocage du GAN aux premiers stades de l'entraînement du GAN lorsque la tâche du discriminateur est très facile. L'article suggère donc de modifier la perte du générateur afin que le générateur essaie de maximiser log D(G(z))
.
Dans TF-GAN, consultez modified_generator_loss pour une implémentation de cette modification.
Perte Wasserstein
Par défaut, TF-GAN utilise la perte de Wasserstein.
Cette fonction de perte dépend d'une modification du schéma GAN (appelé "Wasserstein GAN" ou "WGAN"), dans lequel le discriminateur ne classe pas réellement les instances. Pour chaque instance, il génère un nombre. Ce nombre ne doit pas être inférieur à un ni supérieur à 0.Nous ne pouvons donc pas utiliser 0, 5 comme seuil pour déterminer si une instance est réelle ou fausse. L'entraînement du discriminateur tente simplement de rendre la sortie plus importante pour les instances réelles que pour les instances factices.
Comme il ne peut pas vraiment faire la distinction entre le vrai et le faux, le discriminateur WGAN est en fait appelé "critique" au lieu de "discriminateur". Cette distinction est importante d'un point de vue théorique, mais à des fins pratiques, nous pouvons la considérer comme une reconnaissance du fait que les entrées des fonctions de perte ne doivent pas nécessairement être des probabilités.
Les fonctions de perte elles-mêmes sont trompeuses:
Perte critique: D(x) - D(G(z))
Le discriminateur tente de maximiser cette fonction. En d'autres termes, il tente de maximiser la différence entre sa sortie sur des instances réelles et sa sortie sur des instances factices.
Perte du générateur: D(G(z))
Le générateur tente de maximiser cette fonction. En d'autres termes, il essaie de maximiser la sortie du discriminateur pour ses instances factices.
Dans ces fonctions:
D(x)
correspond à la sortie du critique pour une instance réelle.G(z)
est la sortie du générateur lorsqu'il reçoit le bruit z.D(G(z))
correspond à la sortie du critique pour une fausse instance.- La sortie du critique D ne doit pas être comprise entre 1 et 0.
- Les formules dérivent de la distance de mouvement de la Terre entre les distributions réelles et générées.
Dans TF-GAN, consultez wasserstein_generator_loss et wasserstein_discriminator_loss pour les implémentations.
Conditions requises
La justification théorique du GAN Wasserstein (ou WGAN) exige que les poids de l'ensemble du GAN soient coupés afin qu'ils restent dans une plage limitée.
Avantages
Les GAN Wasserstein sont moins susceptibles de se bloquer que les GAN basés sur le minimax et évitent les problèmes liés à la disparition des gradients. La distance de déplacement de la Terre présente également l'avantage d'être une véritable métrique: une mesure de la distance dans un espace de distributions de probabilité. L'entropie croisée n'est pas une métrique en ce sens.