損失関数の呼ばれた回数を知りたいです
質問
kerasで損失関数を自作しました.その中で,損失関数が呼ばれた回数を使いたいと思いまして,以下のようにcountに逐次+1をするようにコードを書きました.しかし,最後のprint(count)の出力が1になってしまいます.
損失関数は1回しか呼ばれてないのでしょうか.
code
count = 0
def encoder(input_):
d1 = Dense(3, activation='relu', name='encoder_input')(input_)
d2 = Dense(2, 'encoder_output')(d1)
return d2
def decoder(input_):
d1 = Dense(3, activation='relu', name='decoder_input')(input_)
d2 = Dense(2, name='decoder_output')(d1)
return d2
def my_loss_function(y_pred, y_true):
global count
count += 1
return K.mean(K.square(y_pred - y_true), axis=-1)
# input
input = Input(shape=(2,))
# output
output = decoder(encoder(input))
# model
model = Model(input=input, output=output)
model.compile(optimizer='adam', loss='my_loss_function')
model_hist = model.fit(x_train, x_test,
epochs=n_epoch,
batch_size=batch_size,
verbose=verbose,
shuffle=True)
print('count=',count)
#count=1