Les GAN tentent de répliquer une distribution de probabilité. Elles doivent donc utiliser des fonctions de perte qui reflètent la distance entre la distribution des données générées par le GAN et la distribution des données réelles.
Comment faites-vous la différence entre deux distributions dans les fonctions de perte du GAN ? Cette question est un domaine de recherche actif, et de nombreuses approches ont été proposées. Nous allons aborder deux fonctions de perte du GAN ici, toutes deux implémentées dans TF-GAN:
- Perte minimax: fonction de perte utilisée dans l'article présentant les GAN.
- Perte Wasserstein: fonction de perte par défaut pour les Estimators TF-GAN. Décrit pour la première fois dans un article de 2017.
TF-GAN implémente également de nombreuses autres fonctions de perte.
Une ou deux fonctions de perte ?
Un GAN peut avoir deux fonctions de perte: une pour l'entraînement du générateur et une pour l'entraînement du discriminateur. Comment deux fonctions de perte peuvent-elles fonctionner conjointement pour refléter une distance entre des distributions de probabilité ?
Dans les systèmes de perte que nous allons étudier ici, les pertes du générateur et du discriminateur proviennent d'une seule mesure de distance entre les distributions de probabilité. Cependant, dans ces deux schémas, le générateur ne peut affecter qu'un seul terme dans la mesure de distance, c'est-à-dire le terme qui reflète la distribution des données fictives. Ainsi, lors de l'entraînement du générateur, nous abandonnons l'autre terme, qui reflète la distribution des données réelles.
Les pertes de générateur et de discriminateur semblent différentes à la fin, même si elles proviennent d'une seule formule.
Perte minimale
Dans l'article présentant les GAN, le générateur tente de minimiser la fonction suivante, tandis que le discriminateur tente de l'optimiser:
Dans cette fonction:
D(x)
est une estimation de la probabilité que l'instance de données x réelle soit réelle.- Ex est la valeur attendue sur toutes les instances de données réelles.
G(z)
est la sortie du générateur lorsqu'il reçoit une valeur de z pour le bruit.D(G(z))
est l'estimation de la probabilité qu'une fausse instance soit réelle.- Ez est la valeur attendue sur toutes les entrées aléatoires du générateur (en effet, il s'agit de la valeur attendue sur toutes les fausses instances générées, G(z)).
- La formule est dérivée de l'entropie croisée entre les distributions réelles et générées.
Le générateur ne peut pas directement affecter le terme log(D(x))
dans la fonction. Par conséquent, minimiser le fait de minimiser la perte revient à réduire le log(1 -
D(G(z)))
.
Dans TF-GAN, consultez minimax_discriminator_loss and minimax_generator_loss pour obtenir une implémentation de cette fonction de perte.
Perte minimax modifiée
L'article d'origine sur le GAN indique que la fonction de perte minimax ci-dessus peut entraîner le blocage du GAN au tout début de l'entraînement du GAN lorsque la tâche du discriminateur est très facile. L'article suggère donc de modifier la perte du générateur pour qu'il tente de maximiser log D(G(z))
.
Dans TF-GAN, reportez-vous à la section Modified_generator_loss pour une implémentation de cette modification.
Perte de Wasserstein
Par défaut, TF-GAN utilise la perte Wasserstein.
Cette fonction de perte dépend d'une modification du schéma du GAN (nommé "Wasserstein GAN" ou "WGAN") dans lequel le discriminateur ne classe pas les instances. Un nombre est généré pour chaque instance. Ce nombre ne doit pas nécessairement être inférieur à 1 ou supérieur à 0.Nous ne pouvons donc pas utiliser 0, 5 comme seuil pour décider si une instance est réelle ou fausse. L'entraînement des discriminateurs tente simplement d'augmenter la sortie pour les instances réelles plutôt que pour les instances factices.
Étant donné qu'il ne peut pas vraiment faire la distinction entre le vrai et le faux, le discriminateur du WGAN est en réalité appelé un "critique" au lieu d'un "discriminateur". Cette distinction a une importance théorique, mais dans la pratique, nous pouvons la considérer comme le fait que les entrées des fonctions de perte ne doivent pas nécessairement être des probabilités.
Les fonctions de perte elles-mêmes sont d'une simplicité trompeuse:
Perte de critiques:D(x) - D(G(z))
Le discriminateur tente de maximiser cette fonction. En d'autres termes, il tente de maximiser la différence entre sa sortie sur des instances réelles et sa sortie sur de fausses instances.
Perte de générateur : D(G(z))
Le générateur tente de maximiser cette fonction. En d'autres termes, il tente d'optimiser le résultat du discriminateur pour ses fausses instances.
Dans ces fonctions:
D(x)
est le résultat du critique d'une instance réelle.G(z)
est la sortie du générateur lorsqu'il reçoit une valeur de z pour le bruit.D(G(z))
est la sortie du critique d'une fausse instance.- Il n'est pas nécessaire que le résultat du critique D soit compris entre 1 et 0.
- Les formules sont dérivées de la distance de déplacement de la Terre entre les distributions réelles et générées.
Dans TF-GAN, consultez wasserstein_generator_loss et wasserstein_discriminator_loss pour obtenir des implémentations.
Conditions requises
La justification théorique du GAN de Wasserstein (ou WGAN) nécessite que les pondérations dans le GAN soient tronquées de sorte qu'elles restent dans une plage limitée.
Avantages
Les GAN de Wasserstein sont moins vulnérables aux blocages que les GAN à minimax et évitent les problèmes de disparition des gradients. La distance d'accès à la Terre présente l'avantage d'être une véritable métrique: une mesure de distance dans un espace de distributions de probabilité. L'entropie croisée n'est pas une métrique à cet égard.