Пакет настройки данных MNIST

Я нахожусь на этапе обучения модели. Однако, когда я применяю код из учебника: batch_x, batch_y = mnist.train.next_batch(50). Это показывает, что в модели TensorFlow нет атрибута «поезд». Я знаю, что это устаревший код, и я попытался перейти на новую версию TensorFlow. Однако я не смог найти соответствующий код, который может делать то же самое, что и приведенная выше строка кода. Бьюсь об заклад, есть способ, но я не мог придумать ни одного решения.

Я нашел метод, который попросил меня использовать tf.data.Dataset.batch(batch_size). Я пробовал следующим образом, но ни один из них не работает.

a. batch_x, batch_y = mnist.train.next_batch(50)

b. batch_x, batch_y =  tf.data.Dataset.batch(batch_size)

c. batch_x, batch_y =  tf.data.Dataset.batch(50)

d. batch_x, batch_y = mnist.batch(50)

with tf.Session() as sess:

  #FIrst, run vars_initializer to initialize all variables
  sess.run(vars_initializer)

  for i in range(steps):

    #Each batch: 50 images
    batch_x, batch_y = mnist.train.next_batch(50)

    #Train the model
    #Dropout keep_prob (% to keep): 0.5 --> 50% will be dropped out
    sess.run(cnn_trainer, feed_dict = {x: batch_x, y_true: batch_y, hold_prob: 0.5})

    #Test the model: at each 100th step
    #Run this block of code for each 100 times of training, each time run a batch
    if i % 100 == 0:
      print('ON STEP: {}'.format(i))
      print('ACCURACY: ')

      #Compare to find matches of y_pred and y_true
      matches = tf.equal(tf.argmax(y_pred, 1), tf.argmax(y_true, 1))

      #Cast the matches from integers to tf.float32
      #Calculate the accuracy using the mean of matches
      acc = tf.reduce_mean(tf.cast(matches, tf.float32))

      #Test the model at each 100th step
      #Using test dataset
      #Dropout: NONE because of test, not training. 
      test_accuracy = sess.run(acc, feed_dict = {x:mnist.test.images, y_true:mnist.test.labels, hold_prob:1.0})


      print(test_accuracy)
      print('\n')

Вы хотите получить партии из набора данных MNIST?

Shubham Panchal 09.04.2019 09:33

@ Шубхам Панчал, да, я пытаюсь получить batch_x и batch_y.

Maryg 09.04.2019 09:38

Пожалуйста, отформатируйте код правильно.

Mohan Radhakrishnan 09.04.2019 09:54
Стоит ли изучать PHP в 2023-2024 годах?
Стоит ли изучать PHP в 2023-2024 годах?
Привет всем, сегодня я хочу высказать свои соображения по поводу вопроса, который я уже много раз получал в своем сообществе: "Стоит ли изучать PHP в...
Поведение ключевого слова "this" в стрелочной функции в сравнении с нормальной функцией
Поведение ключевого слова "this" в стрелочной функции в сравнении с нормальной функцией
В JavaScript одним из самых запутанных понятий является поведение ключевого слова "this" в стрелочной и обычной функциях.
Приемы CSS-макетирования - floats и Flexbox
Приемы CSS-макетирования - floats и Flexbox
Здравствуйте, друзья-студенты! Готовы совершенствовать свои навыки веб-дизайна? Сегодня в нашем путешествии мы рассмотрим приемы CSS-верстки - в...
Тестирование функциональных ngrx-эффектов в Angular 16 с помощью Jest
В системе управления состояниями ngrx, совместимой с Angular 16, появились функциональные эффекты. Это здорово и делает код определенно легче для...
Концепция локализации и ее применение в приложениях React ⚡️
Концепция локализации и ее применение в приложениях React ⚡️
Локализация - это процесс адаптации приложения к различным языкам и культурным требованиям. Это позволяет пользователям получить опыт, соответствующий...
Пользовательский скаляр GraphQL
Пользовательский скаляр GraphQL
Листовые узлы системы типов GraphQL называются скалярами. Достигнув скалярного типа, невозможно спуститься дальше по иерархии типов. Скалярный тип...
1
3
1 705
2
Перейти к ответу Данный вопрос помечен как решенный

Ответы 2

Это использует TensorFlow 1.11.0 и Keras и предназначено для демонстрации того, как использовать batch. Вы должны адаптировать его к вашим потребностям.

import tensorflow as tf
from tensorflow import keras as k


(x_train, y_train), (X_test, Y_test) = tf.keras.datasets.mnist.load_data()
X_train = x_train.reshape(x_train.shape[0], 28, 28,1)
y_train = tf.keras.utils.to_categorical(y_train,10)
X_test = X_test.reshape(X_test.shape[0], 28, 28,1)
Y_test = tf.keras.utils.to_categorical(Y_test,10)


train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
train_dataset = train_dataset.batch(32)

test_dataset = tf.data.Dataset.from_tensor_slices((X_test, Y_test))
test_dataset = test_dataset.batch(32)


model = tf.keras.models.Sequential([
    tf.keras.layers.Convolution2D(32, (2, 2), activation='relu', input_shape=(28, 28,1)),
    tf.keras.layers.MaxPool2D(pool_size=2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Dropout(0.5),
     tf.keras.layers.Dense(10, activation='softmax')
])

tbCallback = [
    k.callbacks.TensorBoard(
        log_dir = "D:/TensorBoard", histogram_freq=1, write_graph=True, write_images=True
    )
]


model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(train_dataset, epochs = 10, steps_per_epoch = 30,validation_data=test_dataset,validation_steps=1, callbacks=tbCallback)

Спасибо Мохан, попробовал поменять свой керас на более старую версию. и это работает! Большое спасибо!

Maryg 09.04.2019 22:06
Ответ принят как подходящий

Вы можете использовать tf.keras.datasets.mnist.load_data. Он возвращает кортеж массивов Numpy: (x_train, y_train), (x_test, y_test).

После этого вам нужно создать объект набора данных с помощью Dataset API. Это создаст обучающий набор данных. Таким же образом можно создать тестовый набор данных.

train, test = tf.keras.datasets.mnist.load_data()
dataset = tf.data.Dataset.from_tensor_slices((train[0], train[1]))

Затем, чтобы создать пакет, вам нужно применить к нему пакетную функцию.

dataset = dataset.batch(1)

Чтобы вывести его содержимое или использовать в обучении, вам нужно создать итератор. Код ниже создает наиболее распространенный итератор и выводит элемент batch_size в данном случае 1.

iterator = dataset.make_one_shot_iterator()
with tf.Session() as sess:
    print(sess.run(iterator.get_next())

Пожалуйста, прочтите https://www.tensorflow.org/guide/datasets

Спасибо Шарки! это имеет смысл для меня! Я также попытался понизить версию keras до 1.12.0 или ниже, и это тоже работает! Действительно ценю это!

Maryg 09.04.2019 22:07

Другие вопросы по теме