BatchNorm1d нуждается в двумерном вводе?

Я хочу исправить проблему в PyTorch. Я написал следующий код, который изучает синусоидальные функции в качестве учебника.

import torch
from torch import nn
from torch import optim
from torch.autograd import Variable as V
from torch.utils.data import TensorDataset, DataLoader
import numpy as np

# y=sin(x1)
numTrain = 512
numTest = 128
noiseScale = 0.01
PI2 = 3.1415 * 2
X_train = np.random.rand(numTrain,1) * PI2
y_train = np.sin(X_train) + np.random.randn(numTrain,1) * noiseScale + 1.5
X_test  = np.random.rand(numTest,1) * PI2
y_test  = np.sin(X_test) + np.random.randn(numTest,1) * noiseScale

# Construct DataSet
X_trainT = torch.Tensor(X_train)
y_trainT = torch.Tensor(y_train)
X_testT = torch.Tensor(X_test)
y_testT = torch.Tensor(y_test)
ds_train = TensorDataset(X_trainT, y_trainT)
ds_test = TensorDataset(X_testT, y_testT)

# Construct DataLoader
loader_train = DataLoader(ds_train, batch_size=64, shuffle=True)
loader_test = DataLoader(ds_test, batch_size=64, shuffle=False)

# Construct network
net = nn.Sequential(
    nn.Linear(1,10),
    nn.ReLU(),
    nn.BatchNorm1d(10),
    nn.Linear(10,5),
    nn.ReLU(),
    nn.BatchNorm1d(5),
    nn.Linear(5,1),
)
optimizer = optim.Adam(net.parameters())
loss_fn = nn.SmoothL1Loss()

# Training
losses = []
net.train()
for epoc in range(100):
    for data, target in loader_train:
        y_pred = net(data)
        loss = loss_fn(target,y_pred)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.data)


# evaluation
%matplotlib inline
from matplotlib import pyplot as plt

#plt.plot(losses)
plt.scatter(X_train, y_train)

net.eval()
sinsX = []
sinsY = []
for t in range(128):
    x = t/128 * PI2
    output = net(V(torch.Tensor([x])))
    sinsX.append(x)
    sinsY.append(output.detach().numpy())
plt.scatter(sinsX,sinsY)

Обучение выполняется без ошибок, но следующая строка вызвала ошибку «ожидаемый ввод 2D или 3D (получен ввод 1D)»

output = net(V(torch.Tensor([x])))

Эта ошибка не возникает, если не используется BatchNorm1d(). Я чувствую себя странно, потому что вход 1D.

Как это исправить?

Спасибо.

Обновление: как я исправил

arr = np.array([x])
output = net(V(torch.Tensor(arr[None,...])))

Вы должны взглянуть на документацию, там вы можете увидеть, какой ввод ожидает BatchNorm1d. pytorch.org/docs/stable/nn.html#torch.nn.BatchNorm1d

MBT 24.03.2019 08:03
Оптимизация производительности модели: Руководство по настройке гиперпараметров в Python с Keras
Оптимизация производительности модели: Руководство по настройке гиперпараметров в Python с Keras
Настройка гиперпараметров - это процесс выбора наилучшего набора гиперпараметров для модели машинного обучения с целью оптимизации ее...
Развертывание модели машинного обучения с помощью Flask - Angular в Kubernetes
Развертывание модели машинного обучения с помощью Flask - Angular в Kubernetes
Kubernetes - это портативная, расширяемая платформа с открытым исходным кодом для управления контейнерными рабочими нагрузками и сервисами, которая...
Udacity Nanodegree Capstone Project: Классификатор пород собак
Udacity Nanodegree Capstone Project: Классификатор пород собак
Вы можете ознакомиться со скриптами проекта и данными на github .
Определение пород собак с помощью конволюционных нейронных сетей (CNN)
Определение пород собак с помощью конволюционных нейронных сетей (CNN)
В рамках финального проекта Udacity Data Scietist Nanodegree я разработал алгоритм с использованием конволюционных нейронных сетей (CNN) для...
Почему Python - идеальный выбор для проекта AI и ML
Почему Python - идеальный выбор для проекта AI и ML
Блог, которым поделился Harikrishna Kundariya в нашем сообществе Developer Nation Community.
1
1
1 213
1
Перейти к ответу Данный вопрос помечен как решенный

Ответы 1

Ответ принят как подходящий

При работе с одномерными сигналами pyTorch фактически ожидает двумерных тензоров: первое измерение — это «мини-пакетное» измерение. Следовательно, вы должны оценить свою сеть на пакете с одним сигналом 1D:

output - net(V(torch.Tensor([x[None, ...]]))

Убедитесь, что вы установили свою сеть в режим «eval» перед ее оценкой:

net.eval()

Другие вопросы по теме