Когда я натыкаюсь на репозиторий Куанлю на GitHub по обучению cifar-10 с помощью densenet model
, я хочу передать модель графическому процессору, чтобы ускорить процесс обучения. Однако похоже, что входной тензор находится как на процессоре, так и на графическом процессоре. Я подозреваю, что что-то не так в коде пользовательских классов, из-за которого некоторые входные тензоры доступны на процессоре.
Не могли бы вы помочь мне указать на проблему в этом случае? Очень ценю
Для удобства в коде есть функция test()
для проверки правильности компиляции сети. Вот код:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class Bottleneck(nn.Module):
def __init__(self, in_planes, growth_rate):
super(Bottleneck, self).__init__()
self.bn1 = nn.BatchNorm2d(in_planes)
self.conv1 = nn.Conv2d(in_planes, 4*growth_rate, kernel_size=1, bias=False)
self.bn2 = nn.BatchNorm2d(4*growth_rate)
self.conv2 = nn.Conv2d(4*growth_rate, growth_rate, kernel_size=3, padding=1, bias=False)
def forward(self, x):
out = self.conv1(F.relu(self.bn1(x)))
out = self.conv2(F.relu(self.bn2(out)))
out = torch.cat([out,x], 1)
return out
class Transition(nn.Module):
def __init__(self, in_planes, out_planes):
super(Transition, self).__init__()
self.bn = nn.BatchNorm2d(in_planes)
self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False)
def forward(self, x):
out = self.conv(F.relu(self.bn(x)))
out = F.avg_pool2d(out, 2)
return out
class DenseNet(nn.Module):
def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10):
super(DenseNet, self).__init__()
self.growth_rate = growth_rate
num_planes = 2*growth_rate
self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False)
self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0])
num_planes += nblocks[0]*growth_rate
out_planes = int(math.floor(num_planes*reduction))
self.trans1 = Transition(num_planes, out_planes)
num_planes = out_planes
self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1])
num_planes += nblocks[1]*growth_rate
out_planes = int(math.floor(num_planes*reduction))
self.trans2 = Transition(num_planes, out_planes)
num_planes = out_planes
self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2])
num_planes += nblocks[2]*growth_rate
out_planes = int(math.floor(num_planes*reduction))
self.trans3 = Transition(num_planes, out_planes)
num_planes = out_planes
self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3])
num_planes += nblocks[3]*growth_rate
self.bn = nn.BatchNorm2d(num_planes)
self.linear = nn.Linear(num_planes, num_classes)
def _make_dense_layers(self, block, in_planes, nblock):
layers = []
for i in range(nblock):
layers.append(block(in_planes, self.growth_rate))
in_planes += self.growth_rate
return nn.Sequential(*layers)
def forward(self, x):
out = self.conv1(x)
out = self.trans1(self.dense1(out))
out = self.trans2(self.dense2(out))
out = self.trans3(self.dense3(out))
out = self.dense4(out)
out = F.avg_pool2d(F.relu(self.bn(out)), 4)
out = out.view(out.size(0), -1)
out = self.linear(out)
return out
def densenet_cifar():
return DenseNet(Bottleneck, [6,12,24,16], growth_rate=12)
def test():
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = densenet_cifar().to(device)
x = torch.randn(1,3,32,32)
y = net(x)
print(y)
test()
Хорошего дня!
Вам нужно переместить тензор x
на то же устройство.
Замените x = torch.randn(1,3,32,32)
на x = torch.randn(1,3,32,32).to(device)
Спасибо, Карл. Это работает для этого примера.
Я также заметил, что в другом моем примере, когда я создал экземпляр класса в методе forward()
(например: conv1 = nn.Conv2d()), возникла та же ошибка. Но когда я переместил эту строку экземпляра на init
, она работает отлично!
На самом деле нет смысла заменять строку
x = torch.randn(1,3,32,32)
наx = torch.randn(1,3,32,32).to(device)
в вашем вопросе, как вы это сделали после того, как на него ответил Карл. Ваш вопрос должен показывать код, вызывающий проблему, а не иллюстрировать решение. В таком состоянии твой вопрос больше не имел смысла. Я отменил ваши изменения. Когда Карл сказал: «Измените x = torch.randn(1,3,32,32) на x = torch.randn(1,3,32,32).to(device)», Карл имел в виду ваш реальный код, а не в вашем вопросе.