Недавно я попытался создать свой первый проект с нейронными сетями, и вот что у меня получилось. Я хотел, чтобы он распознавал рукописные номера MNIST. Проблема в том, что когда я запускаю этот код и заставляю его тренироваться примерно 400 000 раз, я получаю примерно 28% точности с тестовыми данными. Это должно быть так? 400 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 000 - это слишком мало для получения лучших результатов, или это связано с тем, что моя нейронная сеть может иметь только один скрытый слой?
Подводя итог короткому вопросу, так и должно быть, или я что-то не так сделал? Ниже много избыточного кода и тому подобного, я просто хотел заставить его работать.
Все при условии, что моя нейронная сеть работает, очевидно.
public static void main(String[] args) {
List<Data> trainData = new ArrayList<>();
List<Data> testData = new ArrayList<>();
byte[] trainLabels;
byte[] trainImages;
byte[] testLabels;
byte[] testImages;
try {
Path tempPath1 = Paths.get("res/train-labels-idx1-ubyte");
trainLabels = Files.readAllBytes(tempPath1);
ByteBuffer bufferLabels = ByteBuffer.wrap(trainLabels);
int magicLabels = bufferLabels.getInt();
int numberOfItems = bufferLabels.getInt();
Path tempPath = Paths.get("res/train-images-idx3-ubyte");
trainImages = Files.readAllBytes(tempPath);
ByteBuffer bufferImages = ByteBuffer.wrap(trainImages);
int magicImages = bufferImages.getInt();
int numberOfImageItems = bufferImages.getInt();
int rows = bufferImages.getInt();
int cols = bufferImages.getInt();
for(int i = 0; i < numberOfItems; i++) {
int t = bufferLabels.get();
double[] target = createTargets(t);
double[] inputs = new double[rows*cols];
for(int j = 0; j < inputs.length; j++) {
inputs[j] = bufferImages.get();
}
Data tobj = new Data(inputs, target);
trainData.add(tobj);
}
tempPath = Paths.get("res/t10k-labels-idx1-ubyte");
testLabels = Files.readAllBytes(tempPath);
ByteBuffer testLabelBuffer = ByteBuffer.wrap(testLabels);
int testMagicLabels = testLabelBuffer.getInt();
int numberOfTestLabels = testLabelBuffer.getInt();
tempPath = Paths.get("res/t10k-images-idx3-ubyte");
testImages = Files.readAllBytes(tempPath);
ByteBuffer testImageBuffer = ByteBuffer.wrap(testImages);
int testMagicImages = testImageBuffer.getInt();
int numberOfTestImages = testImageBuffer.getInt();
int testRows = testImageBuffer.getInt();
int testCols = testImageBuffer.getInt();
for(int i = 0; i < numberOfTestImages; i++) {
double[] target = new double[]{testLabelBuffer.get()};
double[] inputs = new double[testRows*testCols];
for(int j = 0; j < inputs.length; j++) {
inputs[j] = testImageBuffer.get();
}
Data tobj = new Data(inputs, target);
testData.add(tobj);
}
NeuralNetwork neuralNetwork = new NeuralNetwork(784,64,10);
int len = trainData.size();
Random randomGenerator = new Random();
for(int i = 0; i < 400000; i++) {
int randomInt = randomGenerator.nextInt(len);
neuralNetwork.train(trainData.get(randomInt).getInputs(), trainData.get(randomInt).getTargets());
}
float rightAnswers = 0;
for(Data testObj : testData) {
double[] output = neuralNetwork.feedforward(testObj.getInputs());
double[] answer = testObj.getTargets();
}
System.out.println(percentage);
} catch (IOException e) {
e.printStackTrace();
}
}
public static double[] createTargets(int number) {
double[] result = new double[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
result[number] = 1;
return result;
}
Если кому интересно, был баг с моей стороны. При регистрации всего я заметил, что значения входных пикселей варьировались от -255 до 255, а из документации MNIST они должны быть 0-255. Кроме того, мои входные данные не были нормализованы, поэтому некоторые из них были равны 0, а другие — 255. Это то, что я добавил. Надеюсь, я ничего не пропустил. Теперь я получаю ~ 90% точности.
for(int i = 0; i < numberOfTestImages; i++) {
double[] target = new double[]{testLabelBuffer.get()& 0xFF};
double[] inputs = new double[testRows*testCols];
or(int j = 0; j < inputs.length; j++) {
// Normalize input from 0-255 to 0-1
double temp = (testImageBuffer.get() & 0xFF) / 255f;
inputs[j] = temp;
}
Data tobj = new Data(inputs, target);
testData.add(tobj);
}