イメージネットのバリデーションのラベルの取得方法
PyTorch (torchvision) を使って ImageNet のバリデーションセットをダウンロードしようとすると、エラーが出ました。
pytorch imagenet datasets
そこで公式の実装 (torchvision.datasets.ImageNet) を参考に以下のように実装してみたのですが、正しいラベルを取得できませんでした。
from __future__ import print_function
import os
import shutil
import tempfile
import torch
from torchvision.datasets.folder import ImageFolder
from torchvision.datasets.utils import check_integrity, download_and_extract_archive, extract_archive, \
verify_str_arg
class ImageNet(ImageFolder):
    """`ImageNet <http://image-net.org/>`_ 2012 Classification Dataset.

    Images are read from ``<root>/<split>`` and the human-readable class
    names from the devkit metadata that ``parse_meta`` loads (by default
    ``<root>/data/meta.mat``).

    Args:
        root (str): Dataset root directory; a leading ``~`` is expanded.
        split (str, optional): Dataset split, ``'train'`` or ``'val'``.
        **kwargs: Forwarded unchanged to ``ImageFolder.__init__``.

    Attributes:
        classes (list): List of the class name tuples.
        class_to_idx (dict): Dict with items (class_name, class_index).
        wnids (list): List of the WordNet IDs.
        wnid_to_idx (dict): Dict with items (wordnet_id, class_index).
        imgs (list): List of (image path, class_index) tuples
        targets (list): The class_index value for each image in the dataset
    """

    def __init__(self, root, split='val', **kwargs):
        root = self.root = os.path.expanduser(root)
        # Remember the split so that `split_folder` can be derived from it.
        self.split = verify_str_arg(split, "split", ("train", "val"))

        # parse_meta is a module-level function returning
        # (idx_to_wnid, wnid_to_classes); only the second mapping is keyed
        # by WordNet ID, which is what the lookup below needs.
        _, wnid_to_classes = parse_meta(root)

        super(ImageNet, self).__init__(self.split_folder, **kwargs)
        # ImageFolder.__init__ was handed the split folder as its root;
        # restore the dataset root afterwards.
        self.root = root

        # ImageFolder derived `classes` from the sub-directory names, which
        # for this layout are WordNet IDs. Keep them under explicit names
        # before replacing `classes` with the readable name tuples.
        self.wnids = self.classes
        self.wnid_to_idx = self.class_to_idx
        self.classes = [wnid_to_classes[wnid] for wnid in self.wnids]
        # Every alias in a class-name tuple maps to the same class index.
        self.class_to_idx = {cls: idx
                             for idx, clss in enumerate(self.classes)
                             for cls in clss}

    @property
    def split_folder(self):
        """Directory holding the images of the selected split."""
        return os.path.join(self.root, self.split)
def parse_devkit(root):
    """Read the devkit metadata and the validation ground truth.

    Args:
        root (str): Directory handed to ``parse_meta`` and
            ``parse_val_groundtruth`` as the devkit root.

    Returns:
        tuple: ``(wnid_to_classes, val_wnids)`` where ``val_wnids`` holds one
        WordNet ID per ground-truth entry, in file order.
    """
    wnid_by_idx, wnid_to_classes = parse_meta(root)

    # Translate each ground-truth class index into its WordNet ID.
    val_wnids = []
    for class_idx in parse_val_groundtruth(root):
        val_wnids.append(wnid_by_idx[class_idx])

    return wnid_to_classes, val_wnids
def parse_meta(devkit_root, path='data', filename='meta.mat'):
    """Load the ``synsets`` table from the devkit metadata file.

    Only records whose fifth field (the number of children) equals 0 are
    kept. From each kept record the first three fields are used: the class
    index, the WordNet ID and the comma-separated class names.

    Args:
        devkit_root (str): Root directory of the devkit.
        path (str): Sub-directory of ``devkit_root`` containing the file.
        filename (str): Name of the MATLAB metadata file.

    Returns:
        tuple: ``(idx_to_wnid, wnid_to_classes)``; the second maps each
        WordNet ID to a tuple of class names.
    """
    import scipy.io as sio

    mat_path = os.path.join(devkit_root, path, filename)
    synsets = sio.loadmat(mat_path, squeeze_me=True)['synsets']

    # Transpose the records into columns; column 4 is the child count.
    columns = list(zip(*synsets))
    child_counts = columns[4]
    leaves = [record
              for record, child_count in zip(synsets, child_counts)
              if child_count == 0]

    idcs, wnids, names = list(zip(*leaves))[:3]
    name_tuples = [tuple(name.split(', ')) for name in names]

    idx_to_wnid = dict(zip(idcs, wnids))
    wnid_to_classes = dict(zip(wnids, name_tuples))
    return idx_to_wnid, wnid_to_classes
def parse_val_groundtruth(devkit_root, path='data',
                          filename='ILSVRC2012_validation_ground_truth.txt'):
    """Read the validation ground-truth class indices.

    The file is expected to contain one integer per line. Blank lines (for
    example a trailing empty line) are skipped instead of raising
    ``ValueError``.

    Args:
        devkit_root (str): Root directory of the devkit.
        path (str): Sub-directory of ``devkit_root`` containing the file.
        filename (str): Name of the ground-truth text file.

    Returns:
        list: The class indices as ``int``, in file order.

    Raises:
        ValueError: If a non-blank line is not a valid integer.
    """
    gt_file = os.path.join(devkit_root, path, filename)
    with open(gt_file, 'r') as txtfh:
        # int() tolerates surrounding whitespace, so only fully blank
        # lines need to be filtered out.
        return [int(line) for line in txtfh if line.strip()]
def prepare_val_folder(folder, wnids):
    """Sort the flat validation images into one sub-directory per WordNet ID.

    The regular files in ``folder`` are sorted by path and paired with
    ``wnids`` in order; each file is moved to ``<folder>/<wnid>/``.

    Args:
        folder (str): Directory containing the validation images.
        wnids (list): WordNet ID for each image, in sorted-filename order.
    """
    # Only regular files are images. Sub-directories (e.g. class folders
    # left over from an earlier run) must not take part in the pairing,
    # otherwise every following file would get the wrong wnid.
    candidates = (os.path.join(folder, name) for name in os.listdir(folder))
    img_files = sorted(entry for entry in candidates if os.path.isfile(entry))

    for wnid in set(wnids):
        class_dir = os.path.join(folder, wnid)
        # Tolerate class directories that already exist.
        if not os.path.isdir(class_dir):
            os.mkdir(class_dir)

    # NOTE: zip stops at the shorter sequence, so a count mismatch between
    # wnids and image files is not reported here.
    for wnid, img_file in zip(wnids, img_files):
        shutil.move(img_file, os.path.join(folder, wnid, os.path.basename(img_file)))
# Load one shuffled batch of four samples and print its labels.
# NOTE(review): `dataset` and `IMAGENET_TRANSFORM` are not defined in the
# visible code -- presumably `torchvision.datasets` and a transform pipeline
# defined elsewhere; confirm.
# NOTE(review): this builds a plain ImageFolder rather than the ImageNet
# class above; ImageFolder presumably takes labels from the sub-directory
# layout of the given path -- verify that the images there are already
# sorted into per-class folders.
testset = dataset.ImageFolder('/path/to/data/', IMAGENET_TRANSFORM)
loader = torch.utils.data.DataLoader(testset, batch_size=4,
                                     shuffle=True, num_workers=2)
dataiter = iter(loader)
# Use the builtin next(); the iterator's .next() method is not available
# on current DataLoader iterators.
images, labels = next(dataiter)
print(labels)
tensor([3, 3, 3, 3])
ILSVRC2012_validation_ground_truth.txt は手元にあります。
正しいラベルを取得するための実装方法や、参考になるサイトがあれば教えてください。