Алгоритм обратного распространения дает плохие результаты

Я пытаюсь решить классическую проблему распознавания рукописных цифр с помощью нейронной сети с прямой связью и обратного распространения, используя набор данных MNIST. Я использую Книга Майкла Нильсена для изучения основ и Видео 3Blue1Brown на YouTube для алгоритма обратного распространения.

Я закончил писать его некоторое время назад и с тех пор занимаюсь отладкой, потому что результаты довольно плохие. В лучшем случае сеть может распознавать ~ 4000/10000 выборок после 1 эпохи, и это число падает только в следующие эпохи, что наводит меня на мысль, что есть какая-то проблема с алгоритмом обратного распространения. Последние несколько дней я тонул в индексном аду, пытаясь отладить это, и не могу понять, в чем проблема, я был бы признателен за любую помощь в указании на нее.

Немного предыстории: 1) Я не использую никакого умножения матриц и никаких внешних фреймворков, а делаю все с помощью циклов for, потому что так я узнал из видео. 2) В отличие от книги, я храню веса и смещения в одном и том же массиве. Смещения для каждого слоя представляют собой столбец в конце матрицы весов для этого слоя.

И наконец для кода это метод Backpropagate класса NeuralNetwork, который вызывается в UpdateMiniBatch, который сам вызывается в SGD:

/// <summary>
/// Returns the partial derivative of the cost function on one sample with respect to every weight in the network.
/// </summary>
public List<double[,]> Backpropagate(ITrainingSample sample)
{
    // Forwards pass
    var (weightedInputs, activations) = GetWeightedInputsAndActivations(sample.Input);

    // The derivative with respect to the activation of the last layer is simple to compute: activation - expectedActivation
    var errors = activations.Last().Select((a, i) => a - sample.Output[i]).ToArray();

    // Backwards pass
    List<double[,]> delCostOverDelWeights = Weights.Select(x => new double[x.GetLength(0), x.GetLength(1)]).ToList();
    List<double[]> delCostOverDelActivations = Weights.Select(x => new double[x.GetLength(0)]).ToList();
    delCostOverDelActivations[delCostOverDelActivations.Count - 1] = errors;

    // Comment notation:
    // Cost function: C
    // Weight connecting the i-th neuron on the (l + 1)-th layer to the j-th neuron on the l-th layer: w[l][i, j]
    // Bias of the i-th neuron on the (l + 1)-th layer: b[l][i]
    // Activation of the i-th neuon on the l-th layer: a[l][i]
    // Weighted input of the i-th neuron on the l-th layer: z[l][i] // which doesn't make sense on layer 0, but is left for index convenience
    // Notice that weights, biases, delCostOverDelWeights and delCostOverDelActivation all start at layer 1 (the 0-th layer is irrelevant to their meanings) while activations and weightedInputs strat at the 0-th layer

    for (int l = Weights.Count - 1; l >= 0; l--)
    {
        //Calculate ∂C/∂w for the current layer:
        for (int i = 0; i < Weights[l].GetLength(0); i++)
            for (int j = 0; j < Weights[l].GetLength(1); j++)
                delCostOverDelWeights[l][i, j] = // ∂C/∂w[l][i, j]
                    delCostOverDelActivations[l][i] * // ∂C/∂a[l + 1][i]
                    SigmoidPrime(weightedInputs[l + 1][i]) * // ∂a[l + 1][i]/∂z[l + 1][i] = ∂(σ(z[l + 1][i]))/∂z[l + 1][i] = σ′(z[l + 1][i])
                    (j < Weights[l].GetLength(1) - 1 ? activations[l][j] : 1); // ∂z[l + 1][i]/∂w[l][i, j] = a[l][j] ||OR|| ∂z[l + 1][i]/∂b[l][i] = 1

        // Calculate ∂C/∂a for the previous layer(a[l]):
        if (l != 0)
            for (int i = 0; i < Weights[l - 1].GetLength(0); i++)
                for (int j = 0; j < Weights[l].GetLength(0); j++)
                    delCostOverDelActivations[l - 1][i] += // ∂C/∂a[l][i] = sum over j:
                        delCostOverDelActivations[l][j] * // ∂C/∂a[l + 1][j]
                        SigmoidPrime(weightedInputs[l + 1][j]) * // ∂a[l + 1][j]/∂z[l + 1][j] = ∂(σ(z[l + 1][j]))/∂z[l + 1][j] = σ′(z[l + 1][j])
                        Weights[l][j, i]; // ∂z[l + 1][j]/∂a[l][i] = w[l][j, i]
    }

    return delCostOverDelWeights;
}

GetWeightedInputsAndActivations:

public (List<double[]>, List<double[]>) GetWeightedInputsAndActivations(double[] input)
{
    List<double[]> activations = new List<double[]>() { input }.Concat(Weights.Select(x => new double[x.GetLength(0)])).ToList();
    List<double[]> weightedInputs = activations.Select(x => new double[x.Length]).ToList();

    for (int l = 0; l < Weights.Count; l++)
        for (int i = 0; i < Weights[l].GetLength(0); i++)
        {
            double value = 0;
            for (int j = 0; j < Weights[l].GetLength(1) - 1; j++)
                value += Weights[l][i, j] * activations[l][j];// weights
            weightedInputs[l + 1][i] = value + Weights[l][i, Weights[l].GetLength(1) - 1];// bias
            activations[l + 1][i] = Sigmoid(weightedInputs[l + 1][i]);
        }

    return (weightedInputs, activations);
}

Всю нейронную сеть, а также все остальное можно найти здесь.

Обновлено: после многих значительных изменений в репо указанная выше ссылка может больше не работать, но, надеюсь, не имеет значения, учитывая ответ. Для полноты картины это функциональная ссылка на измененный репозиторий.

Попробуйте запустить только с двумя нейронами и проверьте промежуточные результаты вручную. Вы также можете проверить СигмоидПрайм, поскольку он отличается от вашей реализации.

mostanes 05.06.2019 00:46

@mostanes Спасибо за комментарий, только что нашел проблему и исправил ее после 5 дней борьбы. Прежде чем я узнал об этом, я сделал тест, в котором вручную рассчитал градиент для 9 весов в сети [2, 3, 1], и он работал нормально. Мой SigmoidPrime — это просто отличный способ написать то же самое, я взял его из книги. Вы можете попробовать проверить это вручную на бумаге.

H. Saleh 05.06.2019 00:54

Ссылка в вашей последней строке не работает.

Matthieu 11.06.2019 22:12
Стоит ли изучать PHP в 2023-2024 годах?
Стоит ли изучать PHP в 2023-2024 годах?
Привет всем, сегодня я хочу высказать свои соображения по поводу вопроса, который я уже много раз получал в своем сообществе: "Стоит ли изучать PHP в...
Поведение ключевого слова "this" в стрелочной функции в сравнении с нормальной функцией
Поведение ключевого слова "this" в стрелочной функции в сравнении с нормальной функцией
В JavaScript одним из самых запутанных понятий является поведение ключевого слова "this" в стрелочной и обычной функциях.
Приемы CSS-макетирования - floats и Flexbox
Приемы CSS-макетирования - floats и Flexbox
Здравствуйте, друзья-студенты! Готовы совершенствовать свои навыки веб-дизайна? Сегодня в нашем путешествии мы рассмотрим приемы CSS-верстки - в...
Тестирование функциональных ngrx-эффектов в Angular 16 с помощью Jest
В системе управления состояниями ngrx, совместимой с Angular 16, появились функциональные эффекты. Это здорово и делает код определенно легче для...
Концепция локализации и ее применение в приложениях React ⚡️
Концепция локализации и ее применение в приложениях React ⚡️
Локализация - это процесс адаптации приложения к различным языкам и культурным требованиям. Это позволяет пользователям получить опыт, соответствующий...
Пользовательский скаляр GraphQL
Пользовательский скаляр GraphQL
Листовые узлы системы типов GraphQL называются скалярами. Достигнув скалярного типа, невозможно спуститься дальше по иерархии типов. Скалярный тип...
8
3
285
1
Перейти к ответу Данный вопрос помечен как решенный

Ответы 1

Ответ принят как подходящий

Фиксированный. Проблема заключалась в следующем: я не разделил входные пиксели на 255. Все остальное работает правильно, и теперь я получаю +9000/10000 в первую эпоху.

Другие вопросы по теме