Pytorch: нет эффекта обучения после глубокого копирования

Я попытался сделать копию нейронной сети в pytorch и впоследствии обучить скопированную сеть, но обучение, похоже, не меняет веса в сети после копирования. Этот пост предполагает, что deepcopy — это удобный способ сделать копию нейронной сети, поэтому я попытался использовать это в своем коде.

Код ниже работает просто отлично и показывает, что веса и точность сети отличаются после обучения и до обучения. Однако, когда я переключаюсь так, чтобы network_cp=deepcopy(network) и optimizer_cp=deepcopy(optimizer), точность и вес до и после тренировки были одинаковыми.

# torch settings
torch.backends.cudnn.enabled = True 
device = torch.device("cpu")

# training settings
learning_rate = 0.01
momentum = 0.5
batch_size_train = 64
batch_size_test = 1000

# get MNIST data set
train_loader, test_loader = load_mnist(batch_size_train=batch_size_train,
    batch_size_test=batch_size_test)

# make a network
network = Net()
optimizer = optim.SGD(network.parameters(), lr=learning_rate,
    momentum=momentum)
network.to(device)

# train network
train(network, optimizer, train_loader, device)

# copy network
network_cp = network
#network_cp = deepcopy(network)
optimizer_cp = optimizer
#optimizer_cp = deepcopy(optimizer)

# get edge weights and accuracy of the copied network
acc1 = float(test(network_cp, optimizer_cp, test_loader, device))
weights1 = np.array(get_edge_weights(network_cp))

# train copied network
train(network_cp, optimizer_cp, train_loader, device)

# get edge weights and accuracy of the copied network after training
acc2 = float(test(network_cp, optimizer_cp, test_loader, device))
weights2 = np.array(get_edge_weights(network_cp))

# compare edge weights and accuracy of copied network before and after training
print('accuracy', acc1, acc2)
print('abs diff of weights for net1 and net2', np.sum(np.abs(weights1-weights2)))

Чтобы запустить приведенный выше код, включите следующие импорты и определения функций:

import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as tnn
import torch.nn.functional as tnf
from copy import deepcopy
import numpy as np

def load_mnist(batch_size_train = 64, batch_size_test = 1000):
    train_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST('temp/', #'/data/users/alice/pytorch_training_files/',
                                   train=True, download=True,
                                   transform=torchvision.transforms.Compose([
                                       torchvision.transforms.ToTensor(),
                                       torchvision.transforms.Normalize(
                                           (0.1307,), (0.3081,))
                                       ])),
        batch_size=batch_size_train, shuffle=True)

    test_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST('temp/', #'/data/users/alice/pytorch_training_files/',
                                   train=False, download=True,
                                   transform=torchvision.transforms.Compose([
                                       torchvision.transforms.ToTensor(),
                                       torchvision.transforms.Normalize(
                                           (0.1307,), (0.3081,))
                                     ])),
        batch_size=batch_size_test, shuffle=True)

    return(train_loader, test_loader)

def train(network, optimizer, train_loader, device, n_epochs=5):
    network.train()
    for epoch in range(1, n_epochs + 1):
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = network(data)
            loss = tnf.nll_loss(output, target)
            loss.backward()
            optimizer.step()

def test(network, optimizer, test_loader, device):
    network.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = network(data)
            test_loss += tnf.nll_loss(output, target, size_average=False).item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

    return(float(correct)/float(len(test_loader.dataset)))

def get_edge_weights(network):
    layers = [module for module in network.modules()][1:]
    output = np.zeros(1)
    for j, layer in enumerate(layers):
        weights = list(layer.parameters())[0]
        weights_arr = weights.detach().numpy()
        weights_arr = weights_arr.flatten()
        output = np.concatenate((output,weights_arr))
    return output[1:]

class Net(tnn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 =tnn.Linear(784,264)
        self.fc2 = tnn.Linear(264,10)

    def forward(self, x):
        x = tnf.relu(self.fc1(x.view(-1,784)))
        x = tnf.relu(self.fc2(x))
        return tnf.log_softmax(x)
Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
2
0
2 515
1
Перейти к ответу Данный вопрос помечен как решенный

Ответы 1

Ответ принят как подходящий

После optimizer_cp = deepcopy(optimizer)optimizer_cp все еще хочет оптимизировать параметры старой модели (как определено optimizer = optim.SGD(network.parameters(), lr=learning_rate, momentum=momentum)).

После глубокого копирования модели оптимизатору необходимо указать оптимизировать параметры этой новой модели:

optimizer_cp = optim.SGD(network_cp.parameters(), lr=learning_rate, momentum=momentum)

сбросит ли это параметры оптимизатора? Если да, то как я могу их сохранить?

Moltres 26.06.2021 17:27

Другие вопросы по теме