Я настроил базовую модель UNET. При использовании функции для обучения модели напрямую она оптимизируется. Однако при использовании аналогичного цикла в молнии pytorch с заданным шагом поезда потери не изменяются по сравнению с исходным значением. Я убрал биты zero_grad/backward/step на основе этого урока. Что я делаю неправильно?
# Optimizes well
def train(dataloader, model, loss_fn, optimizer):
size = len(dataloader.dataset)
model.train()
for batch, (X, y) in enumerate(dataloader):
X, y = X.to('cuda',dtype=torch.float), y.to('cuda',dtype=torch.float)
# Compute prediction error
pred = model(X)
loss = loss_fn(pred, y)
# Backpropagation
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Using this as a function inside the UNet class, which I feed to pytorch_lightning.Trainer.
# Loss does not update from initial value. Model predictions do not improve.
def training_step(self, batch, batch_idx):
X,y = batch
X, y = X.to(self.device,dtype=torch.float), y.to(self.device,dtype=torch.float)
# Compute prediction error
pred = self.forward(X)
loss = self.loss_fn(pred, y)
self.log("train_loss", loss)
return loss
Эта проблема была вызвана следующей строкой в классе модели:
def configure_optimizers(self):
return super().configure_optimizers()
Один из онлайн-потоков рекомендовал использовать это вместе с training_step и train_dataloader в качестве минимального набора методов для запуска молнии pytorch. Однако на самом деле эта строка мешает оптимизации — возможно, каждый раз загружается одна и та же партия, чтобы потери не улучшались. Простое удаление этого метода устраняет проблему. LightningModule.fit принимает загрузчик данных и использует его для передачи пакетов в training_step.
Было бы лучше, если бы вы вставили воспроизводимый код. например не понятно что такое
self.forward
,self.loss_fn