現在Tensorflowを用いてオートエンコーダの構成を行っていて、中間層の活性化関数をステップ関数にしてみたのですが、MNISTの画像を用いたところ学習がうまくいっているように見られました。
しかし、ステップ関数はニューラルネットワークにおいて勾配を持たないため学習はうまくいかないと習っていたので何故うまくいっているのかが分かりません。その点について教えていただきたいです。以下が実際のコードです。

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
import tensorflow as tf

# Import data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("../MNIST_data/", one_hot=True)


# Variables
x = tf.placeholder("float", [None, 784])
y_ = tf.placeholder("float", [None, 10])
h=128 #中間層のノード数
w_enc = tf.Variable(tf.random_normal([784, h], mean=0.0, stddev=0.05))
w_dec = tf.Variable(tf.random_normal([h, 784], mean=0.0, stddev=0.05))
# w_dec = tf.transpose(w_enc) # if you use tied weights
b_enc = tf.Variable(tf.zeros([h]))
b_dec = tf.Variable(tf.zeros([784]))

# Create the model
def model(X, w_e, b_e, w_d, b_d):
    encoded = tf.round(tf.sigmoid(tf.matmul(X, w_e) + b_e))
#中間層の処理
    decoded = tf.sigmoid(tf.matmul(encoded, w_d) + b_d)
#出力層の処理
    return encoded, decoded

encoded, decoded = model(x, w_enc, b_enc, w_dec, b_dec)

# Cost Function basic term
cross_entropy = -1. * x * tf.log(decoded) - (1. - x) * tf.log(1. - decoded)
loss = tf.reduce_mean(cross_entropy)
train_step = tf.train.AdamOptimizer(1e-4).minimize(loss)

# Train
init = tf.initialize_all_variables()

with tf.Session() as sess:
    sess.run(init)
    print('Training...')
    for i in range(11001):
        batch_xs, batch_ys = mnist.train.next_batch(256)
        train_step.run({x: batch_xs, y_: batch_ys})

        if i % 1000 == 0:
            train_loss = loss.eval({x: batch_xs, y_: batch_ys})
            print('  step, loss = %6d: %6.3f' % (i, train_loss))

    # generate decoded image with test data
    test_fd = {x: mnist.test.images, y_: mnist.test.labels}
    decoded_imgs = decoded.eval(test_fd)
    print('loss (test) = ', loss.eval(test_fd))

x_test = mnist.test.images

n = 10  # 
plt.figure(figsize=(20, 4))
for i in range(n):
    # display original
    ax = plt.subplot(2, n, i + 1)
    plt.imshow(x_test[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    # display reconstruction
    ax = plt.subplot(2, n, i + 1 + n)
    plt.imshow(decoded_imgs[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.savefig('mnist_ae_step.png')

このコード自体が自分で一から作ったものではないので、間違いが多いかもしれませんが、よろしければ教えていただきたいです。お願いします。