scikit-learn/k-meansを使った画像のクラスタリングでのエラー
ディレクトリ内のjpg画像ファイルをSURFを用いて、特徴量を抽出し、すべてのSURFをk-means法でグループ化して基本特徴量(visual word)を求め、これを使って画像の局所特徴量リストをbag-of-wordsリストにするプログラムがあります。
試しに、90枚ほどの画像でやるとうまくグループわけができましたが、いざ1900枚ほどの画像でやってみると、以下のエラーメッセージが出てきました。
Traceback (most recent call last):
File "sample.py", line 27, in <module>
c = km.predict(d)
File "C:\Anaconda3\lib\site-packages\sklearn\cluster\k_means_.py", line 1460, in predict
X = self._check_test_data(X)
File "C:\Anaconda3\lib\site-packages\sklearn\cluster\k_means_.py", line 794, in _check_test_data warn_on_dtype=True)
File "C:\Anaconda3\lib\site-packages\sklearn\utils\validation.py", line 407, in check_array context))
ValueError: Found array with 0 sample(s) (shape=(0, 64)) while a minimum of 1 is required.
どうすれば正しく実行できますでしょうか?
以下、コードです。
import mahotas as mh
import numpy as np
from glob import glob
from mahotas.features import surf
from sklearn.cluster import MiniBatchKMeans
from sklearn.feature_extraction.text import TfidfTransformer
picture_category_num = 5
feature_category_num = 128
# image surf
images = glob('./*.jpg')
alldescriptors = []
for im in images:
im = mh.imread(im, as_grey=True)
im = im.astype(np.uint8)
alldescriptors.append(surf.surf(im, descriptor_only=True))
# image surf -> basic feature
concatenated = np.concatenate(alldescriptors)
km = MiniBatchKMeans(feature_category_num)
km.fit(concatenated)
# image surf and basic feature -> features
features = []
for d in alldescriptors:
c = km.predict(d)
features.append(np.array([np.sum(c == ci) for ci in range(feature_category_num)]))
features = np.array(features)
# features -> tfidf
transformer = TfidfTransformer()
tfidf = transformer.fit_transform(features)
tfidf.toarray()
# not use tfidf
# tfidf = features
# categorization
km = MiniBatchKMeans(n_clusters=picture_category_num, init='random', n_init=1, verbose=1)
km.fit(tfidf)
# print result
images = np.array(images)
print('completed')
f = open("result.txt","w")
for i in range(picture_category_num):
print('image category{0}'.format(i), file=f)
print(images[km.labels_ == i], file=f)
else:
f.close()