Autoencoderの中間層ノードを増やして出力の再現精度を上げようとしたのですが、うまく行かなかったので質問させていただきます。

Autoencoder は例えば、784 -> 32 -> 784 な構造で組めば、32次元に圧縮された特徴情報が得られる、というような手法であると理解しています。

そこで、784 -> 32 -> 784 な Autoencoder が32次元から784次元の情報を再現するのなら、784 -> 784 -> 784 なネットワークはより上手く、というよりほぼ損失なく入力を復元する恒等写像を作るのでは?と考え、Fashion-MNISTデータセットを使ったコードを書いてみたのですが、出力される画像は期待するほど鮮明になりませんでした。(オリジナルのMNISTはそこそこ鮮明になりました)

Fashion-MNIST

もちろん、中間層のノードを入出力より少なく設定しなければ Autoencoder を使う意味はほぼ無いという事は理解しているのですが、中間層のノードを増やしても恒等写像を高精度に学習できない理由が何であるか知りたいのです。

上記のFashion-MNISTの例のように恒等写像ですら高精度に出来ない=次元を減らさないネットワークを使っても上手く情報を残すとは言いがたい(壊してしまう?)状態であるなら、次元削減を目的に使った場合にも精度の悪い特徴を残してしまうだけではないのでしょうか?

原理的には、784次元の情報を再現するには784次元以上の情報があれば十分じゃないのかと思うのですが、高精度にならない理由やより良い実装、参考となる資料などをご存じでしたらお答え頂けると有り難いです。

追記:上記画像を生成する際に使ったコード全文(Python3 & Keras)を追記しておきます。

from keras.datasets import fashion_mnist
from keras.layers import Input, Dense
from keras.models import Model

import matplotlib.pyplot as plt

(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

x_train = x_train.reshape(60000, 784)
x_test = x_test.reshape(10000, 784)
x_train = x_train.astype("float32")
x_test = x_test.astype("float32")
x_train /= 255.0
x_test /= 255.0

input_img = Input(shape=(784,))
encoded = Dense(784, activation="relu")(input_img)
decoded = Dense(784, activation="sigmoid")(encoded)

autoencoder = Model(inputs=input_img, outputs=decoded)
autoencoder.compile(optimizer="adam", loss="binary_crossentropy")
autoencoder.fit(x_train, x_train,
                epochs=5,
                batch_size=128,
                shuffle=True,
                validation_data=(x_test, x_test))

# テスト画像を変換
decoded_imgs = autoencoder.predict(x_test)

# 何個表示するか
n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
    # オリジナルのテスト画像を表示
    ax = plt.subplot(2, n, i+1)
    plt.imshow(x_test[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    # 変換された画像を表示
    ax = plt.subplot(2, n, i+1+n)
    plt.imshow(decoded_imgs[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()