Las GAN intentan replicar una distribución de probabilidad. Por lo tanto, deben usar funciones de pérdida que reflejen la distancia entre la distribución de los datos generados por la GAN y la distribución de los datos reales.
¿Cómo capturas la diferencia entre dos distribuciones en las funciones de pérdida de GAN? Esta pregunta es un área de investigación activa y se propusieron muchos enfoques. Aquí abordaremos dos funciones de pérdida de GAN comunes, que se implementan en TF-GAN:
- Pérdida minimax: Es la función de pérdida que se usa en el artículo en el que se presentaron las GAN.
- Pérdida de Wasserstein: Es la función de pérdida predeterminada para los estimadores de TF-GAN. Se describió por primera vez en un artículo de 2017.
TF-GAN también implementa muchas otras funciones de pérdida.
¿Una función de pérdida o dos?
Una GAN puede tener dos funciones de pérdida: una para el entrenamiento del generador y otra para el entrenamiento del discriminador. ¿Cómo pueden dos funciones de pérdida trabajar juntas para reflejar una medida de distancia entre distribuciones de probabilidad?
En los esquemas de pérdidas que analizaremos aquí, las pérdidas del generador y del discriminador provienen de una sola medida de distancia entre las distribuciones de probabilidad. Sin embargo, en ambos esquemas, el generador solo puede afectar un término en la medida de distancia: el término que refleja la distribución de los datos falsos. Por lo tanto, durante el entrenamiento del generador, descartamos el otro término, que refleja la distribución de los datos reales.
Las pérdidas del generador y del discriminador se ven diferentes al final, aunque provienen de una sola fórmula.
Pérdida minimax
En el artículo en el que se presentaron las GAN, el generador intenta minimizar la siguiente función, mientras que el discriminador intenta maximizarla:
En esta función:
D(x)
es la estimación del discriminador de la probabilidad de que la instancia de datos reales x sea real.- Ex es el valor esperado en todas las instancias de datos reales.
G(z)
es el resultado del generador cuando se le proporciona el ruido z.D(G(z))
es la estimación del discriminador de la probabilidad de que una instancia falsa sea real.- Ez es el valor esperado sobre todas las entradas aleatorias del generador (en efecto, el valor esperado sobre todas las instancias falsas generadas G(z)).
- La fórmula se deriva de la entropía cruzada entre las distribuciones reales y generadas.
El generador no puede afectar directamente el término log(D(x))
en la función, por lo que, para el generador, minimizar la pérdida equivale a minimizar log(1 -
D(G(z)))
.
En TF-GAN, consulta minimax_discriminator_loss y minimax_generator_loss para ver una implementación de esta función de pérdida.
Pérdida minimax modificada
En el artículo original de GAN, se señala que la función de pérdida minimax anterior puede hacer que la GAN se bloquee en las primeras etapas del entrenamiento de GAN cuando el trabajo del discriminador es muy fácil. Por lo tanto, el artículo sugiere modificar la pérdida del generador para que este intente maximizar log D(G(z))
.
En TF-GAN, consulta modified_generator_loss para ver una implementación de esta modificación.
Pérdida de Wasserstein
De forma predeterminada, TF-GAN usa la pérdida de Wasserstein.
Esta función de pérdida depende de una modificación del esquema de GAN (denominado "Wasserstein GAN" o "WGAN") en el que el discriminador no clasifica instancias. Para cada instancia, muestra un número. Este número no tiene que ser menor que uno o mayor que 0, por lo que no podemos usar 0.5 como umbral para decidir si una instancia es real o falsa. El entrenamiento del discriminante solo intenta hacer que el resultado sea mayor para las instancias reales que para las falsas.
Debido a que no puede discriminar entre lo real y lo falso, el discriminador de WGAN en realidad se llama "crítico" en lugar de "discriminador". Esta distinción tiene importancia teórica, pero, a los efectos prácticos, podemos considerarla como un reconocimiento de que las entradas de las funciones de pérdida no tienen que ser probabilidades.
Las funciones de pérdida son engañosamente simples:
Pérdida de crítica: D(x) - D(G(z))
El discriminador intenta maximizar esta función. En otras palabras, intenta maximizar la diferencia entre su salida en instancias reales y su salida en instancias falsas.
Pérdida del generador: D(G(z))
El generador intenta maximizar esta función. En otras palabras, intenta maximizar el resultado del discriminador para sus instancias falsas.
En estas funciones:
D(x)
es el resultado del crítico para una instancia real.G(z)
es el resultado del generador cuando se le proporciona el ruido z.D(G(z))
es el resultado del crítico para una instancia falsa.- El resultado del crítico D no tiene que estar entre 1 y 0.
- Las fórmulas se derivan de la distancia de mover la tierra entre las distribuciones reales y generadas.
En TF-GAN, consulta wasserstein_generator_loss y wasserstein_discriminator_loss para ver las implementaciones.
Requisitos
La justificación teórica de la GAN de Wasserstein (o WGAN) requiere que las ponderaciones de toda la GAN se recorten para que permanezcan dentro de un rango limitado.
Beneficios
Las GAN de Wasserstein son menos vulnerables a quedarse atascadas que las GAN basadas en minimax y evitan los problemas con los gradientes que desaparecen. La distancia de movimiento de tierra también tiene la ventaja de ser una métrica real: una medida de la distancia en un espacio de distribuciones de probabilidad. La entropía cruzada no es una métrica en este sentido.