GAN は確率分布を再現しようとします。したがって、GAN によって生成されたデータの分布と実際のデータの分布の距離を反映する損失関数を使用する必要があります。
GAN 損失関数で 2 つの分布の違いを捉えるにはどうすればよいですか?この問題は現在研究が進められている分野であり、多くのアプローチが提案されています。ここでは、TF-GAN で実装されている 2 つの一般的な GAN 損失関数について説明します。
- ミニマックス損失: GAN を紹介した論文で使用されている損失関数。
- Wasserstein 損失: TF-GAN 推定子のデフォルトの損失関数。2017 年の論文で初めて説明されました。
TF-GAN には、他にも多くの損失関数が実装されています。
損失関数は 1 つと 2 つどちらがいいですか?
GAN には、ジェネレータのトレーニング用と弁別子のトレーニング用の 2 つの損失関数を設定できます。2 つの損失関数を組み合わせて、確率分布間の距離測定を反映するにはどうすればよいですか?
ここで説明する損失スキームでは、生成元と識別子の損失は、確率分布間の距離の単一の測定値から導出されます。ただし、どちらのスキーマでも、生成ツールは距離測定の 1 つの項(偽データの分布を反映する項)にのみ影響できます。したがって、生成ツールのトレーニング中に、実際のデータの分布を反映するもう一方の項を削除します。
生成器と識別器の損失は、1 つの式から導き出されるにもかかわらず、最終的には異なります。
ミニマックス損失
GAN を紹介した論文では、ジェネレータは次の関数を最小限に抑えようとし、識別子は最大化しようとします。
この関数では、次の処理を行います。
D(x)
は、実際のデータ インスタンス x が本物である確率に対するディスкриминレータの推定値です。- Ex は、すべての実際のデータ インスタンスの期待値です。
G(z)
は、ノイズ z が指定された場合の生成元の出力です。D(G(z))
は、偽のインスタンスが本物である確率に対するディスкриминレータの推定値です。- Ez は、生成ツールへのすべてのランダム入力の期待値(実際には、生成されたすべての偽のインスタンス G(z) の期待値)です。
- この式は、実際の分布と生成された分布のクロス エントロピーから導出されます。
生成ツールは関数の log(D(x))
項に直接影響できないため、生成ツールにとって損失の最小化は log(1 -
D(G(z)))
の最小化と同等です。
TF-GAN でこの損失関数を実装するには、minimax_discriminator_loss と minimax_generator_loss をご覧ください。
修正ミニマックス損失
元の GAN の論文では、上記のミニマックス損失関数により、ディスкриминレータのジョブが非常に簡単な場合、GAN が GAN トレーニングの初期段階で停止する可能性があると記載されています。したがって、この論文では、生成元が log D(G(z))
を最大化するように生成元の損失を変更することを提案しています。
TF-GAN でこの変更を実装する方法については、modified_generator_loss をご覧ください。
Wasserstein 損失
デフォルトでは、TF-GAN は Wasserstein 損失を使用します。
この損失関数は、ディスкриминエータがインスタンスを実際に分類しない GAN スキーム(「Wasserstein GAN」または「WGAN」)の変更に依存しています。インスタンスごとに数値を出力します。この数値は 1 未満または 0 より大きい必要はありません。そのため、インスタンスが本物か偽物かを判断するしきい値として 0.5 を使用できません。弁別子のトレーニングでは、偽のインスタンスよりも実際のインスタンスの出力を大きくしようとします。
本物と偽物を実際に区別できないため、WGAN の識別子は「識別子」ではなく「批評家」と呼ばれます。この区別は理論的には重要ですが、実用的には、損失関数への入力が確率である必要がないことを認識しているものとして扱うことができます。
損失関数自体は、一見シンプルに見えます。
批評家による評価: D(x) - D(G(z))
ディスクリミネータは、この関数を最大化しようとします。つまり、実際のインスタンスでの出力と偽のインスタンスでの出力の差を最大化しようとします。
Generator Loss: D(G(z))
生成ツールは、この関数を最大化しようとします。つまり、偽のインスタンスに対する識別子の出力を最大化しようとします。
これらの関数では、次の処理を行います。
D(x)
は、実際のインスタンスのクリティックの出力です。G(z)
は、ノイズ z が指定された場合の生成元の出力です。D(G(z))
は、偽のインスタンスに対するクリティックの出力です。- クリティック D の出力は 1 ~ 0 の範囲に限られません。
- これらの式は、実際の分布と生成された分布の間の地球移動距離から導出されます。
TF-GAN での実装については、wasserstein_generator_loss と wasserstein_discriminator_loss をご覧ください。
要件
Wasserstein GAN(WGAN)の理論的根拠では、制約付きの範囲内に収まるように、GAN 全体の重みをクリップする必要があります。
利点
Wasserstein GAN は、ミニマックスベースの GAN よりも行き詰まりの影響を受けにくく、勾配消失の問題を回避できます。土砂移動距離には、真の指標であるという利点もあります。これは、確率分布空間における距離の測定値です。クロスエントロピーは、この意味での指標ではありません。