infoGANを実装した(い)
infoGANの論文を読み,MNIST用の実装をPyTorchで行った記録です.
論文は2016年6月に出ているので1年ほど前のもの.
タイトルを日本語訳すると,「情報量を最大化する敵対的生成ネットワークによる解釈可能な表現の学習」でしょうか.
infoGANの特徴
・GAN
・明快な表現を獲得する
・教師なし学習
・相互情報量の最大化(隠れ符号と生成画像間の)
・DCGANの一部を変えるだけで作れ,訓練も簡単
相互情報量は,2つの確率変数のうち1つがわかった時に,もう1つの確率変数に関してどれほど推測できるかを示すので,2つが独立な時には相互情報量は0になります.
今回のinfoGANでは,その相互情報量を最大化するので,隠れ符号と生成画像の相互依存性を高めるように学習させようということになります.
実装に関して
・ネットワーク
DCGANをベースにして,論文のAppendix通りにQネットワークを付け足す.
隠れ符号cを出力するQnetとDnetは畳み込み層全てを共有する.
出力層に関して,論文中に記載されていなかったように思うので,雑に設定してしまった.
Dの出力層は,ノード2つ,ソフトマックス関数
Qの出力層は,ノード12つ,
うちカテゴリカル符号は10次元分,
うちコンティニュアス符号は2次元分.
Gの入力には,62次元のzと,12次元のcを単純につなぎ合わせた74次元のノイズ.
・損失関数
Dの損失関数は,2つのノードをからソフトマックスをへてクロスエントロピーを計算.realの時には1を出力,fakeの時には0を出力できているか.
(F.CrossEntropyLoss でlogSoftMaxが行われるから2回もソフトマックス計算している??)
Gの損失関数は,Gを経由した生成画像がDでrealつまり1と出力されるかでクロスエントロピー.
Qの損失関数は,カテゴリカル符号は10次元のone-hotで評価.ターゲットにしている数字を示すラベルが出てきているかでクロスエントロピー誤差.
コンティニュアス符号は,論文中にこのような記述.
For continuous latent codes, we parameterize the approximate posterior through a diagonal Gaussian distribution,
出力された符号は,対角ガウス分布に従うようにパラメタライズするようです.
正直よくわからなかったので,
GitHub - hvy/chainer-infogan: Chainer implementation of InfoGAN
こちらの方のchainerによる実装を参考にさせてもらったところ,gaussian_nllというロス関数を使っていたので,真似してやり過ごしました.
・backward
どこで .detach() するかは重要ですが,DCGANと同じでいいはずで,Dをバックワードする時にGに行かないようにする.GはDごとバックワード.そして,出力されたcはGもDも通って来ていてそのままバックワードすることにしました.
結果
・カテゴリカル符号(10次元)をone-hotで入力した時の出力
入力[1,0,0,0,0,0,0,0,0,0], [0,1,0,0,0,0,0,0,0,0],,,,
順番は揃っていませんが,数字を各カテゴリに分けることができています.
・コンティニュアス符号(1次元が2つ)
各行ごとにカテゴリ(数字の種類)を固定して,一つのcontinuous cを変化.
数字の傾きが変化していることから,一つのcはローテーションという表現を獲得していることがわかります.
もう一つのcontinuous cを変化.
1がはっきりしていますが,太いものから細いものへと変化しています.cが太さという表現を獲得しています.
ここまで200epochほどで学習が少ないかもしれませんが,DCGANなどでは解釈が難しかったzですが,infoGANによって1つの符号が明快な表現を意味するようになったことを確認できました.
改善すべき点は多いですがコードはここにあります.
自己符号化器,AutoEncoderの実装
機械学習プロフェッショナルシリーズの「深層学習」のChapter5を参考に,PyTorchでAutoEncoderの実装を行いました.
パラメータとしては,
入出力層が28x28次元,
中間層が100次元,
(28x28 -> 100 -> 28x28)
中間層の活性化関数はReLU,
出力層の活性化関数は恒等写像,
重みはガウス分布(σ=0.01)で初期化,
SGD(重み減衰λ=0.1,モメンタムµ=0.5
Loss関数は二乗誤差
にしました.
SGDよりもAdamの方がLossが落ちたのでAdamに変更しました.
結果としては,
このようになり,数字は認識できます.画像全体として(背景が?)白くなっているのは,重み0が濃淡256階層の中央に来るように調整することで治るのかもしれません.
中間層のユニットを100個よりも大きくすると,よりLossも下がり,よりくっきりとした画像が得られました.中間層のユニットの数が特徴の表現力を表していることを実感できます.
コードはここにあります.