DCGANを行う際に「'numpy.ndarray' object is not callable」というエラーが発生
DCGANで画像の自動生成を行う際に「'numpy.ndarray' object is not callable」というエラーが発生します。コードは下のものです。GoogleColaboratoryで実行しています。
# 本物データをGeneratorで生成したデータのスケールを-1~1で揃える
def scale(x, feature_ranges=(-1, 1)):
# 0~1に変換
x = ((x - x.min()) / (255 - x.min()))
# -1~1に変換
min, max = feature_ranges
x = x * (max - min) + min
return x
class Dataset:
# val_fracでテストデータを学習中と学習後用に分離する
# スケール関数は上のものを使うためscale_func=None
def __init__(self, shuffle= False, scale_func=None):
self.test_x, self.valid_x = test_img2, valid_img2
self.test_y, self.valid_y = test_index, valid_index
self.train_x, self.train_y = train_img2, train_index
if scale_func is None:
self.scaler = scale
else:
self.scaler = scale_func
self.shuffle = shuffle
# ミニバッチ生成の定義
def batches(self, batch_size):
if (self.shuffle).any():
idx = np.arange(len(dataset.train_x))
np.random.shuffle(idx)
self.train_x = self.train_x[idx]
self.train_y = self.train_y[idx]
n_batches = len(self.train_y) // batch_size
for ii in range(0, len(self.train_y), batch_size):
x = self.train_x[ii:ii+batch_size]
y = self.train_y[ii:ii+batch_size]
yield self.scaler(x), y
(中略)
# トレーニングの実行
dataset = Dataset(train_img2, test_img2)
losses, samples = train(net, dataset, epochs, batch_size, figsize=(10, 5))
これを実行すると次のようなエラーが出ます。
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-16-79c936bb63bd> in <module>()
2 dataset = Dataset(train_img2, test_img2)
3
----> 4 losses, samples = train(net, dataset, epochs, batch_size, figsize=(10, 5))
1 frames
<ipython-input-14-07c181927757> in train(net, dataset, epochs, batch_size, print_every, show_every, figsize)
14 for e in range(epochs):
15 # バッチで取り出してパラメータの更新を行う
---> 16 for x, y in dataset.batches(batch_size):
17 # for文のたびにstep数を1増加
18 steps += 1
<ipython-input-6-f57b4d686916> in batches(self, batch_size)
26 y = self.train_y[ii:ii+batch_size]
27
---> 28 yield self.scaler(x), y
TypeError: 'numpy.ndarray' object is not callable
定義したself.scaler
を使ってxとyを返したいのですが,このようなエラーが出ます。
どのように修正すればこのエラーを解決できるのでしょうか?