質問

kerasで損失関数を自作しました.その中で,損失関数が呼ばれた回数を使いたいと思いまして,以下のようにcountに逐次+1をするようにコードを書きました.しかし,最後のprint(count)の出力が1になってしまいます.
損失関数は1回しか呼ばれてないのでしょうか.

code

count = 0

def encoder(input_):
    d1 = Dense(3, activation='relu', name='encoder_input')(input_)
    d2 = Dense(2, 'encoder_output')(d1)
    return d2

def decoder(input_):
    d1 = Dense(3, activation='relu', name='decoder_input')(input_)
    d2 = Dense(2, name='decoder_output')(d1)
    return d2

def my_loss_function(y_pred, y_true):
    global count
    count += 1
    return K.mean(K.square(y_pred - y_true), axis=-1)  

# input
input = Input(shape=(2,))

# output
output = decoder(encoder(input))

# model
model = Model(input=input, output=output)

model.compile(optimizer='adam', loss='my_loss_function')
model_hist = model.fit(x_train, x_test,
                        epochs=n_epoch,
                        batch_size=batch_size,
                        verbose=verbose,
                        shuffle=True)

print('count=',count)
#count=1