teratailにても投稿済です
https://teratail.com/questions/225527
概要
物体検出プログラム作成時、独自データセットの中を見ようとして
(first_datum = train_dataset[0])
エラー。参考サイトのプログラムでは動作したため、データセットの中身がおかしい?とかんがえられます。まったくわからないため教えていただきたいです。よろしくお願いします。
エラーの出たプログラム
import chainer
import cupy
import chainercv
import matplotlib
import os
import xml.etree.ElementTree as ET
import numpy as np
from chainercv.datasets import VOCBboxDataset
bccd_labels = ('ROCK', 'MOUNTAIN')#変更箇所
class BCCDDataset(VOCBboxDataset):
def _get_annotations(self, i):
id_ = self.ids[i]
anno = ET.parse(
os.path.join(self.data_dir, 'Annotations', id_ + '.xml'))
bbox = []
label = []
difficult = []
for obj in anno.findall('object'):
bndbox_anno = obj.find('bndbox')
bbox.append([
int(bndbox_anno.find(tag).text) - 1
for tag in ('ymin', 'xmin', 'ymax', 'xmax')])
name = obj.find('name').text.lower().strip()
label.append(bccd_labels.index(name))
print(bccd_labels.index(obj.find('name').text.lower().strip()))
bbox = np.stack(bbox).astype(np.float32)
label = np.stack(label).astype(np.int32)
difficult = np.array(difficult, dtype=np.bool)
return bbox, label, difficult
train_dataset = BCCDDataset('rock_detect_dataset/BCCD', 'train')
valid_dataset = BCCDDataset('rock_detect_dataset/BCCD', 'val')
test_dataset = BCCDDataset('rock_detect_dataset/BCCD', 'test')
print('Number of images in "train" dataset:', len(train_dataset))
print('Number of images in "valid" dataset:', len(valid_dataset))
print('Number of images in "test" dataset:', len(test_dataset))
first_datum = train_dataset[0]
参考サイト・コード
!pip install chainercv # ChainerCVのインストール
import chainer
import cupy
import chainercv
import matplotlib
!if [ ! -d BCCD_Dataset ]; then git clone https://github.com/Shenggan/BCCD_Dataset.git; fi
import os
import xml.etree.ElementTree as ET
import numpy as np
from chainercv.datasets import VOCBboxDataset
bccd_labels = ('rbc', 'wbc', 'platelets')
class BCCDDataset(VOCBboxDataset):
def _get_annotations(self, i):
id_ = self.ids[i]
anno = ET.parse(
os.path.join(self.data_dir, 'Annotations', id_ + '.xml'))
bbox = []
label = []
difficult = []
for obj in anno.findall('object'):
bndbox_anno = obj.find('bndbox')
bbox.append([
int(bndbox_anno.find(tag).text) - 1
for tag in ('ymin', 'xmin', 'ymax', 'xmax')])
name = obj.find('name').text.lower().strip()
label.append(bccd_labels.index(name))
bbox = np.stack(bbox).astype(np.float32)
label = np.stack(label).astype(np.int32)
difficult = np.array(difficult, dtype=np.bool)
return bbox, label, difficult
train_dataset = BCCDDataset('BCCD_Dataset/BCCD', 'train')
valid_dataset = BCCDDataset('BCCD_Dataset/BCCD', 'val')
test_dataset = BCCDDataset('BCCD_Dataset/BCCD', 'test')
first_datum = train_dataset[0]
xml例
<annotation>
<folder>JPGImages</folder>
<filename>hogehoge.jpg</filename>
<path>hogehoge</path>
<source>
<database>Unknown</database>
</source>
<size>
<width>720</width>
<height>1080</height>
<depth>3</depth>
</size>
<segmented>0</segmented>
<object>
<name>MOUNTAIN</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>242.6123427801162</xmin>
<ymin>279.4234910191407</ymin>
<xmax>720</xmax>
<ymax>706.8194993286604</ymax>
</bndbox>
</object>
</annotation>
エラーコード
ValueErrorTraceback (most recent call last)
<ipython-input-59-8cd07929f6af> in <module>
----> 1 first_datum = train_dataset[0]
2 print(first_datum[0].shape, first_datum[0].dtype)
/usr/local/lib/python3.6/dist-packages/chainer/dataset/dataset_mixin.py in __getitem__(self, index)
65 return [self.get_example(i) for i in index]
66 else:
---> 67 return self.get_example(index)
68
69 def __len__(self):
/usr/local/lib/python3.6/dist-packages/chainercv/chainer_experimental/datasets/sliceable/sliceable_dataset.py in get_example(self, index)
96 if isinstance(self.keys, tuple):
97 return self.get_example_by_keys(
---> 98 index, tuple(range(len(self.keys))))
99 else:
100 return self.get_example_by_keys(index, (0,))[0]
/usr/local/lib/python3.6/dist-packages/chainercv/chainer_experimental/datasets/sliceable/getter_dataset.py in get_example_by_keys(self, index, key_indices)
87 _, getter_index, key_index = self._keys[key_index]
88 if getter_index not in cache:
---> 89 cache[getter_index] = self._getters[getter_index](index)
90 if key_index is None:
91 example.append(cache[getter_index])
<ipython-input-55-501a46a571b5> in _get_annotations(self, i)
33 for tag in ('ymin', 'xmin', 'ymax', 'xmax')])
34 name = obj.find('name').text.lower().strip()
---> 35 label.append(bccd_labels.index(name))
36 print(bccd_labels.index(obj.find('name').text.lower().strip()))
37 bbox = np.stack(bbox).astype(np.float32)
ValueError: tuple.index(x): x not in tuple