Pytorch backprop медленнее по сравнению с Tensorflow?

Я реализовал простую сеть DDQN в pytorch и tensorflow. Сеть довольно мелкая. Хотя прямой проход намного быстрее в PyTorch по сравнению с TF, шаг обратного распространения намного медленнее по сравнению с TF. Оба шага обратного распространения были выполнены на ЦП. Есть идеи, как его улучшить.

Сетевая часть:

def __init__(self, hidden_size_IP=100, hidden_size_rest=100, alpha=0.01, state_size=27, action_size=8,
             learning_rate=1e-6):
    super().__init__()

    # build hidden layers
    self.l1 = nn.Sequential(nn.Linear(in_features=500, out_features=400),
                            nn.LeakyReLU(negative_slope=alpha))
    self.l2 = nn.Sequential(nn.Linear(in_features=400, out_features=200),
                            nn.LeakyReLU(negative_slope=alpha))
    self.l3 = nn.Sequential(nn.Linear(in_features=200, out_features=200),
                            nn.LeakyReLU(negative_slope=alpha))
    # build output layer
    self.Qval = nn.Linear(in_features=200, out_features=24)

def forward(self, observation):
    if isinstance(observation, np.ndarray):
        observation = torch.from_numpy(observation).float()
    out1 = self.l1(observation)
    out2 = self.l2(out1)
    out3 = self.l3(out2)
    qval = self.Qval(out3)
    return qval

и код обратного распространения может быть, например:

self.optimizer = optim.Adam(self.q_net.parameters(), lr=1e-4)
self.optimizer.zero_grad()

state_batch=torch.rand([64,500])
act_batch=np.randi(0,24,[64,1]
act_batch_torch=torch.as_tensor(act_batch)
label_batch = torch.rand([64,500])
Q=self.q_net.forward(state_batch).gather(1, act_batch_torch) # q_net is an instance of the network above
loss = mse_loss(input=Q, target=label_batch.detach())
loss.backward()

self.optimizer.step()

Обратите внимание, что, поскольку логический вывод выполняется намного быстрее при использовании ЦП, я также делаю обратную передачу на ЦП. Я попытался передать сеть на графический процессор, а затем выполнить обратную передачу на графическом процессоре, но это оказалось медленнее.

Есть идеи, почему pyTorch медленнее? Как я могу улучшить скорость для этого типа неглубокой сети?

Это правильный код? Потому что вы делаете optimizer.zero_grad сразу после loss.backward. Разве вы не должны сделать это перед вычислением обратного распространения ошибки?

kvish 02.01.2019 11:49

Да, ты прав. Однако проблема с синхронизацией сохраняется ...

Eli 02.01.2019 19:56

к сожалению, у меня ограниченный опыт работы как с pytorch, так и с обучением с подкреплением. Может быть, объединение всей модели в одну последовательную сеть может помочь ускорить обратное распространение? Вы так же сконструировали в Tensorflow?

kvish 03.01.2019 12:42
Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
1
3
480
0

Другие вопросы по теме