Настройка функциональной модели в Керасе

Я просто возился с Keras для развлечения (для продолжения обучения), и у меня возникли некоторые проблемы с указанием структуры данных в CNN.

Tutorial: https://github.com/bhaveshoswal/CNN-text-classification-keras
Data: https://catalog.data.gov/dataset/consumer-complaint-database/resource/2f297213-7198-4be1-af1e-2d2623e7f6e9

Итак, у меня настроена обработка данных, и все выглядит правильно.

from keras.layers import Input, Dense, Embedding, Conv2D, MaxPool2D
from keras.layers import Reshape, Flatten, Dropout, Concatenate
from keras.callbacks import ModelCheckpoint
from keras.optimizers import Adam
from keras.models import Model
from sklearn.model_selection import train_test_split
from helper_functions import load_data

print('Loading data')
x, y, vocabulary, vocabulary_inv = load_data()

# x.shape -> (50000, 5371)
# y.shape -> (50000, 10)
# len(vocabulary) -> 50111
# len(vocabulary_inv) -> 50111

X_train, X_test, y_train, y_test = train_test_split( x, y, test_size=0.2, random_state=42)

# X_train.shape -> (40000, 5371)
# y_train.shape -> (40000, 18)
# X_test.shape -> (10000, 5371)
# y_test.shape -> (10000, 18)

Когда я пытаюсь подогнать модель, я получаю ошибку, связанную с размерами набора данных.

sequence_length = x.shape[1] # 56
vocabulary_size = len(vocabulary_inv) # 18765
embedding_dim = 256
filter_sizes = [3,4,5]
num_filters = 512
drop = 0.5

epochs = 100
batch_size = 30

# this returns a tensor
print("Creating Model...")

inputs = Input(shape=(sequence_length,), dtype='int32')
embedding = Embedding(input_dim=vocabulary_size, output_dim=embedding_dim, input_length=sequence_length)(inputs)
reshape = Reshape((sequence_length,embedding_dim,1))(embedding)

conv_0 = Conv2D(num_filters, kernel_size=(filter_sizes[0], embedding_dim), padding='valid', kernel_initializer='normal', activation='relu')(reshape)
conv_1 = Conv2D(num_filters, kernel_size=(filter_sizes[1], embedding_dim), padding='valid', kernel_initializer='normal', activation='relu')(reshape)
conv_2 = Conv2D(num_filters, kernel_size=(filter_sizes[2], embedding_dim), padding='valid', kernel_initializer='normal', activation='relu')(reshape)

maxpool_0 = MaxPool2D(pool_size=(sequence_length - filter_sizes[0] + 1, 1), strides=(1,1), padding='valid')(conv_0)
maxpool_1 = MaxPool2D(pool_size=(sequence_length - filter_sizes[1] + 1, 1), strides=(1,1), padding='valid')(conv_1)
maxpool_2 = MaxPool2D(pool_size=(sequence_length - filter_sizes[2] + 1, 1), strides=(1,1), padding='valid')(conv_2)

concatenated_tensor = Concatenate(axis=1)([maxpool_0, maxpool_1, maxpool_2])
flatten = Flatten()(concatenated_tensor)
dropout = Dropout(drop)(flatten)
output = Dense(units=18, activation='softmax')(dropout)

# this creates a model that includes
model = Model(inputs=inputs, outputs=output)

checkpoint = ModelCheckpoint('weights.{epoch:03d}-{val_acc:.4f}.hdf5', monitor='val_acc', verbose=1, save_best_only=True, mode='auto')
adam = Adam(lr=1e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0)

#model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

print("Traning Model...")
model.fit(X_train, y_train, batch_size=batch_size, epochs=epochs, verbose=1, callbacks=[checkpoint], validation_data=(X_test, y_test))  # starts training

Вот сообщение об ошибке:

Traning Model...
Traceback (most recent call last):

  File "<ipython-input-294-835f1e289b39>", line 41, in <module>
    model.fit(X_train, y_train, batch_size=batch_size, epochs=epochs, verbose=1, callbacks=[checkpoint], validation_data=(X_test, y_test))  # starts training

  File "/Users/abrahammathew/anaconda3/lib/python3.6/site-packages/keras/engine/training.py", line 1630, in fit
    batch_size=batch_size)

  File "/Users/abrahammathew/anaconda3/lib/python3.6/site-packages/keras/engine/training.py", line 1480, in _standardize_user_data
    exception_prefix='target')

  File "/Users/abrahammathew/anaconda3/lib/python3.6/site-packages/keras/engine/training.py", line 123, in _standardize_input_data
    str(data_shape))

ValueError: Error when checking target: expected dense_11 to have shape (1,) but got array with shape (18,)

Учитывая эту ошибку, как должны выглядеть измерения данных при создании моделей Keras с помощью функционального API.

@ DanielMöller исходный код не ... Думаю, я должен попробовать его с 2, что было исходным значением

AGUY 28.05.2018 19:11

Думаю, я ошибся, см. Ответ @nuric.

Daniel Möller 28.05.2018 19:18

В любом случае, 18 - правильное число.

Daniel Möller 28.05.2018 19:19
Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
1
3
58
1
Перейти к ответу Данный вопрос помечен как решенный

Ответы 1

Ответ принят как подходящий

У вас есть sparse_categorical_crossentropy, который ожидает только целочисленные метки классов, тогда как вы уже даете закодированные версии (18,). Таким образом, вам необходимо заменить loss='categorical_crossentropy', чтобы устранить проблему.

Другие вопросы по теме