以下の画像を入力として数値を出力するCNN回帰モデルを作りたい。
画像の下に記載されている数字がラベルになります。
inputs
このように差がはっきりと分かる入力画像に対してだと、ほぼ100%の正解率になると思っているのですが、
学習結果が収束しません。
改善点をアドバイスいただけるとありがたいです。

発生している問題

損失が以下のように推移し、約60epoch移行は精度が向上しません。
losses
200epoch目の結果は以下のような状態で、正解ラベルが0.1から0.9の間だとすると、ほとんど正解とは言えない状況です。
(outputsがモデルの出力結果、labelsが正解データ)

outputs: tensor([[0.4102]], grad_fn=<AddmmBackward>)
labels: tensor([0.9000])
loss: 0.2399456799030304
outputs: tensor([[0.5098]], grad_fn=<AddmmBackward>)
labels: tensor([0.1000])
loss: 0.16796395182609558
outputs: tensor([[0.2965]], grad_fn=<AddmmBackward>)
labels: tensor([0.6000])
loss: 0.09208495169878006
outputs: tensor([[0.3554]], grad_fn=<AddmmBackward>)
labels: tensor([0.3000])
loss: 0.003064962802454829
[200,     4] loss: 0.126

ソースコード

  • 画像を一枚ずつロードし、2乗誤差を計算してモデルをアップデート
  • モデルはAlexnetの最後にfc層を付け加えて、回帰モデルにしたもの
  • optimizerはAdamを使用
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import csv

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader, random_split

from models.alexnet_reg import AlexNet


class ImageDataset(Dataset):
    def __init__(self, csv_file, transform=None):
        self.csv = pd.read_csv(csv_file)
        self.transform = transform

    def __len__(self):
        return len(self.csv)

    def __getitem__(self, idx):
        img_path = self.csv.at[idx.item(), 'img']
        label = self.csv.at[idx.item(), 'mu']
        img = Image.open(img_path)

        if self.transform:
            img = self.transform(img)

        return img, label


def train(model, train_dataset):
    n_epoch = 200
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=1)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters())
    model.train()
    losses = []

    for epoch in range(n_epoch):
        running_loss = 0.0
        for i, data in enumerate(train_loader, 0):
            inputs, labels = data
            optimizer.zero_grad()
            labels = labels.type(torch.FloatTensor)

            outputs = model(inputs)
            print('outputs:', outputs)
            print('labels:', labels)
            loss = criterion(outputs, labels)
            print('loss:', loss.item())
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # if i % 10 == 9:
            #     print('[%d, %5d] loss: %.3f' %
            #           (epoch + 1, i + 1, running_loss / 10))
            #     running_loss = 0.0

        print('[%d, %5d] loss: %.3f' %
              (epoch + 1, i + 1, running_loss / 4))
        losses.append(running_loss/4)

    with open('losses.csv', 'w') as f:
        writer = csv.writer(f, lineterminator='/n')
        writer.writerow(losses)

    plt.plot(losses)
    plt.title('MSELoss')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.show()

def main():
    transform = transforms.Compose(
        [transforms.Resize(224),
         transforms.ToTensor()]
    )
    dataset = ImageDataset(csv_file='./data/dataset00/dataset00.csv', transform=transform)

    train_size = int(1.0 * len(dataset))
    valid_size = len(dataset) - train_size
    train_dataset, valid_dataset = random_split(dataset, [train_size, valid_size])
    print('train_size:', train_size, ', valid_size:', valid_size)

    model = AlexNet()
    train(model, train_dataset)


if __name__ == '__main__':
    main()
import torch.nn as nn
from torchvision import models


class AlexNet(nn.Module):
    def __init__(self):
        super(AlexNet, self).__init__()
        self.alexnet = models.alexnet(pretrained=True)
        self.classifier2 = nn.Sequential(
            nn.Dropout(),
            nn.Linear(1000, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(256, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(64, 16),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(16, 1)
        )

    def forward(self, x):
        x = self.alexnet(x)
        x = self.classifier2(x)
        return x

その他試したこと

ラベルをすべて正の数値にしてみたところ、より発散した。
losses_posi

バージョン

torch==1.0.1.post2
torchvision==0.2.2.post3