Я пытаюсь оптимизировать следующие параметры, сохраняя параметр L всегда ниже треугольника с положительной диагональю, а параметр шума всегда только с положительной диагональю во время оптимизации, но они неправильно обновляются при прямом проходе. Думаю, я что-то не так делаю с механизмом автоградации. Любая помощь приветствуется. Вот образец фрагмента, иллюстрирующий проблему - использование вымышленных вычислений - с использованием PyTorch 0.40.
import torch
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from torch.nn import ParameterList
class Model(torch.nn.Module):
def __init__(self, dim):
"""
Constructor.
"""
super(Model, self).__init__()
self.noise_vector = Parameter(torch.tensor(torch.zeros(D).cuda(), requires_grad=True))
self.noise = Parameter(torch.tensor(torch.diag(torch.exp(self.noise_vector.data)).cuda(), requires_grad=True))
self.L_chol_cov_theta = Parameter(torch.tensor(torch.randn(dim, dim).cuda(), requires_grad=True))
self.log_diag_L_chol_cov_theta = Parameter(torch.tensor(torch.randn(dim).cuda(), requires_grad=True))
self.L = Parameter(torch.tensor(torch.randn(dim, dim).cuda(), requires_grad=True))
self.L_chol_cov_theta.data = torch.tril(self.L_chol_cov_theta.data)
self.L_chol_cov_theta.data -= torch.diag(torch.diag(self.L_chol_cov_theta.data))
self.L.data = self.L_chol_cov_theta.data + torch.diag(torch.exp(self.log_diag_L_chol_cov_theta.data))
def forward(self):
# update parameters
self.L_chol_cov_theta.data -= torch.diag(torch.diag(self.L_chol_cov_theta.data))
self.L.data = self.L_chol_cov_theta.data + torch.diag(torch.exp(self.log_diag_L_chol_cov_theta.data))
self.noise.data = torch.diag(torch.exp(self.noise_vector.data))
return torch.mm(self.L, self.noise_vector.view(-1,-1))
model = Model(5)
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
for epoch in range(10):
optimizer.zero_grad()
forward_pass_something = model()
loss = calc_likelihood(samples, a_ground_truth) # calc a custom loss
loss.backward()
optimizer.step()
@ManojAcharya, потому что, если я работаю напрямую с переменными, я получаю ошибки операции на месте.
@gelazari, тогда вы должны знать, почему вы получаете оперативные ошибки на месте и как их решить. прочтите внимательно: pytorch.org/docs/master/notes/…






Почему вы работаете с данными переменных, а не с переменной напрямую? Если вы работаете с данными, я не думаю, что он будет делать обратную передачу правильно.