Позвольте мне предварить это, сказав, что я очень новичок в нейронных сетях, и я впервые использую numpy, tensorflow или keras.
Я написал нейронную сеть для распознавания рукописных цифр, используя набор данных MNIST. Я следил за этот учебник от Sentdex и заметил, что он использует print(np.argmax(predictions[0]))
для печати первого индекса из пустого массива прогнозов.
Я попытался запустить программу с заменой этой строки на print(predictions[i])
(у меня было установлено значение 0), но на выходе было не число, а:
[2.1975785e-08 1.8658861e-08 2.8842608e-06 5.7113186e-05 1.2067199e-10
7.2511304e-09 1.6282028e-12 9.9993789e-01 1.3356166e-08 2.0409643e-06]
.
Мой код, чем я запутался:
predictions = model.predict(x_test)
for i in range(10):
plt.imshow(x_test[i])
plt.show()
print("PREDICTION: ", predictions[i])
Я прочитал документацию numpy для функции argmax() и, насколько я понимаю, она принимает x-мерный массив, преобразует его в одномерный массив, а затем возвращает индекс наибольшего значения. В документации Keras для model.predict() указано, что функция возвращает пустой массив прогнозов сетей. Поэтому я не понимаю, почему мы должны использовать argmax() для корректной печати прогноза, потому что, как я понимаю, он имеет совершенно постороннее назначение.
Извините за плохое форматирование кода, я не мог понять, как правильно вставлять многострочные фрагменты кода в свой пост.
Так что на самом деле программа работает только потому, что в массиве numpy предсказания [0] 7-й индекс фактически соотносит угадываемое число 7 как наиболее вероятное.
Если я хорошо понимаю ваш вопрос, ответ довольно прост:
надеюсь понятно ахах
То, что выдает любая классификационная нейронная сеть, представляет собой распределение вероятностей по индексам класса, что означает, что сеть назначает одну вероятность каждому классу. Сумма этих вероятностей равна 1,0. Затем сеть обучается назначать наибольшую вероятность правильному классу, поэтому для восстановления индекса класса из вероятностей вы должны выбрать местоположение (индекс), которое имеет максимальную вероятность. Это делается с помощью операции argmax
.
Этот argmax получает индекс наибольшего значения в
predictions[0]
. Думаю покажет 7, индекс значения 9.99e-01