Chainerユーザーです。Chainerを使ってVAEを実装しました。参考にしたURLは

Variational Autoencoder徹底解説
AutoEncoder, VAE, CVAEの比較

PyTorch+Google ColabでVariational Auto Encoderをやってみた

などです。実装したコードのコアになる部分は以下の通りです。

class VAE(chainer.Chain):

    def __init__(self, n_in, n_latent, n_h, act_func=F.tanh):
        super(VAE, self).__init__()
        self.act_func = act_func
        with self.init_scope():
            # encoder
            self.le1        = L.Linear(n_in, n_h)
            self.le2        = L.Linear(n_h,  n_h)
            self.le3_mu     = L.Linear(n_h,  n_latent)
            self.le3_ln_var = L.Linear(n_h,  n_latent)

            # decoder
            self.ld1 = L.Linear(n_latent, n_h)
            self.ld2 = L.Linear(n_h,      n_h)
            self.ld3 = L.Linear(n_h,      n_in)

    def __call__(self, x, sigmoid=True):
        return self.decode(self.encode(x)[0], sigmoid)

    def encode(self, x):
        h1 = self.act_func(self.le1(x))
        h2 = self.act_func(self.le2(h1))
        mu = self.le3_mu(h2)
        ln_var = self.le3_ln_var(h2) 
        return mu, ln_var

    def decode(self, z, sigmoid=True):
        h1 = self.act_func(self.ld1(z))
        h2 = self.act_func(self.ld2(h1))
        h3 = self.ld3(h2)
        if sigmoid:
            return F.sigmoid(h3)
        else:
            return h3

    def get_loss_func(self, C=1.0, k=1):
        def lf(x):
            mu, ln_var = self.encode(x)
            batchsize = len(mu.data)
            # reconstruction error
            rec_loss = 0
            for l in six.moves.range(k):
                z = F.gaussian(mu, ln_var)
                z.name = "z"
                rec_loss += F.bernoulli_nll(x, self.decode(z, sigmoid=False)) / (k * batchsize)
            self.rec_loss = rec_loss
            self.rec_loss.name = "reconstruction error"
            self.latent_loss = C * gaussian_kl_divergence(mu, ln_var) / batchsize
            self..name = "latent loss"
            self.loss = self.rec_loss + self.latent_loss
            self.loss.name = "loss"
            return self.loss
        return lf

rec_lossは再構成誤差、すなわち入力と出力がどの程度等しいかを表していて、latent_lossの方は特徴量空間における分布が正規分布からどれくらいことなるを表す誤差だと認識しています。

MNISTで実験してみた結果、
1.lossが減少していく
2.再構成がちゃんと行われる(input画像が3なら、output画像も3になっている)
3.特徴量空間でランダムサンプリングを行った結果、ちゃんと数字が出力される。
などが確かめられました。

疑問

ところで、疑問なのですが、rec_lossは再構成誤差なので、素朴には平均二乗誤差をつかうのが自然だと思われます。そこでrec_lossの部分を

rec_loss += F.mean_squared_error(x, self.decode(z)) / k

と書き換え、ほかの条件は全部そのままで(他の部分は一切書き換えずに)、実験すると
1. rec_lossが2epoch目以降、全く減少しない。
2. 再構成が行われない。適当なinputを与えても、意味のないoutput画像が得られる。(ちなみにinput画像の種類によらずoutput画像は一定のようです)
などの結果が得られて、学習がうまくいっていっていないようです。

これはなぜでしょうか。

追記

bernoulli_nllはデフォルトではすべて合計する一方で、mean_squared_errorは二乗誤差をバッチとピクセルの両方で平均するので、再構成項が過小評価されてしまっているかもしれないと考えました。

MNSITは28×28の画像なので、MSEを用いる際には28×28×F.mean_squared_error(x, decode(z))とすれば良いと思い、試してみましたが結果は変わりませんでした。

追記2

chainerでMSEを使ったVAEの実装を行っているコードを見つけました。
https://github.com/maguro27/VAE-CIFAR10_chainer/blob/master/VAE_CIFAR10.ipynb
なぜ、このコードでは動いて、私の上のコードでは学習がうまくいかないのでしょうか。

追記3

optimizerの問題ではないかとの指摘を受けて、Adam,AdaDelta,SGDで試しましたが結果は変わらず…

追記4

レスポンスがないので以下2か所でも同様の質問を行いました。

解決策?

rec_loss += F.mean_squared_error(x, self.decode(z)) / k 

rec_loss += F.mean(F.sum((x - self.decode(z)) ** 2, axis=1))

に変えるとうまくいきます。でもなぜだ…?