ChainerでVAEを作るときにloss関数をbernoulli_nllではなくMSEを使うと学習が進まない。
Chainerユーザーです。Chainerを使ってVAEを実装しました。参考にしたURLは
・Variational Autoencoder徹底解説
・AutoEncoder, VAE, CVAEの比較
・PyTorch+Google ColabでVariational Auto Encoderをやってみた
などです。実装したコードのコアになる部分は以下の通りです。
class VAE(chainer.Chain):
def __init__(self, n_in, n_latent, n_h, act_func=F.tanh):
super(VAE, self).__init__()
self.act_func = act_func
with self.init_scope():
# encoder
self.le1 = L.Linear(n_in, n_h)
self.le2 = L.Linear(n_h, n_h)
self.le3_mu = L.Linear(n_h, n_latent)
self.le3_ln_var = L.Linear(n_h, n_latent)
# decoder
self.ld1 = L.Linear(n_latent, n_h)
self.ld2 = L.Linear(n_h, n_h)
self.ld3 = L.Linear(n_h, n_in)
def __call__(self, x, sigmoid=True):
return self.decode(self.encode(x)[0], sigmoid)
def encode(self, x):
h1 = self.act_func(self.le1(x))
h2 = self.act_func(self.le2(h1))
mu = self.le3_mu(h2)
ln_var = self.le3_ln_var(h2)
return mu, ln_var
def decode(self, z, sigmoid=True):
h1 = self.act_func(self.ld1(z))
h2 = self.act_func(self.ld2(h1))
h3 = self.ld3(h2)
if sigmoid:
return F.sigmoid(h3)
else:
return h3
def get_loss_func(self, C=1.0, k=1):
def lf(x):
mu, ln_var = self.encode(x)
batchsize = len(mu.data)
# reconstruction error
rec_loss = 0
for l in six.moves.range(k):
z = F.gaussian(mu, ln_var)
z.name = "z"
rec_loss += F.bernoulli_nll(x, self.decode(z, sigmoid=False)) / (k * batchsize)
self.rec_loss = rec_loss
self.rec_loss.name = "reconstruction error"
self.latent_loss = C * gaussian_kl_divergence(mu, ln_var) / batchsize
self..name = "latent loss"
self.loss = self.rec_loss + self.latent_loss
self.loss.name = "loss"
return self.loss
return lf
rec_lossは再構成誤差、すなわち入力と出力がどの程度等しいかを表していて、latent_lossの方は特徴量空間における分布が正規分布からどれくらいことなるを表す誤差だと認識しています。
MNISTで実験してみた結果、
1.lossが減少していく
2.再構成がちゃんと行われる(input画像が3なら、output画像も3になっている)
3.特徴量空間でランダムサンプリングを行った結果、ちゃんと数字が出力される。
などが確かめられました。
疑問
ところで、疑問なのですが、rec_lossは再構成誤差なので、素朴には平均二乗誤差をつかうのが自然だと思われます。そこでrec_lossの部分を
rec_loss += F.mean_squared_error(x, self.decode(z)) / k
と書き換え、ほかの条件は全部そのままで(他の部分は一切書き換えずに)、実験すると
1. rec_lossが2epoch目以降、全く減少しない。
2. 再構成が行われない。適当なinputを与えても、意味のないoutput画像が得られる。(ちなみにinput画像の種類によらずoutput画像は一定のようです)
などの結果が得られて、学習がうまくいっていっていないようです。
これはなぜでしょうか。
追記
bernoulli_nllはデフォルトではすべて合計する一方で、mean_squared_errorは二乗誤差をバッチとピクセルの両方で平均するので、再構成項が過小評価されてしまっているかもしれないと考えました。
MNSITは28×28の画像なので、MSEを用いる際には28×28×F.mean_squared_error(x, decode(z))とすれば良いと思い、試してみましたが結果は変わりませんでした。
追記2
chainerでMSEを使ったVAEの実装を行っているコードを見つけました。
(https://github.com/maguro27/VAE-CIFAR10_chainer/blob/master/VAE_CIFAR10.ipynb)
なぜ、このコードでは動いて、私の上のコードでは学習がうまくいかないのでしょうか。
追記3
optimizerの問題ではないかとの指摘を受けて、Adam,AdaDelta,SGDで試しましたが結果は変わらず…
追記4
レスポンスがないので以下2か所でも同様の質問を行いました。
- https://stackoverflow.com/questions/56469800/vae-does-not-learn-when-i-change-reconstruction-loss-functions-f-bernoulli-nll-t
- https://teratail.com/questions/192884
解決策?
rec_loss += F.mean_squared_error(x, self.decode(z)) / k
を
rec_loss += F.mean(F.sum((x - self.decode(z)) ** 2, axis=1))
に変えるとうまくいきます。でもなぜだ…?