Tensorflow: странное поведение с разделением строки в tf.data.Dataset

Я использую API tf.data.Dataset в Tensorflow. У меня есть 2 массива numpy, где data - 2-й, а labels - 1-й. Я создал Dataset вот так:

dataset = tf.data.Dataset.from_tensor_slices((data, labels))
val_dataset = dataset.map(lambda x, y: ({'reviews': x}, y))

У меня есть функция предварительной обработки, которую я хотел бы использовать, она выглядит так:

def preprocess(x, y):
    # split on whitespace
    x['reviews'] = tf.string_split(x['reviews'])
    return x, y

Я пытаюсь использовать map вот так:

dataset = dataset.map(preprocess)

но я возвращаюсь:

ValueError: Shape must be rank 1 but is rank 0 for 'StringSplit' (op: 'StringSplit') with input shapes: [], [].

Я погуглил и обнаружил, что кто-то предложил этот подход в функции предварительной обработки:

x['reviews'] = tf.string_split([x['reviews']])

Но мне непонятно, зачем я это сделал. Это не ошибка, как раньше, но форма моих данных неверна. Например, вот что я вижу для первого элемента в моем dataset:

({'sequence': array([[ 6391,  3352, 10236,   244,  1362,   244,  9350,  7649,  6391,
         6324,  6063,  3620,   244,  8153,  6542, 10056,  7303,  1955,
         1362,  6194, 10250,  6391,   550,   244,  7577,   850,  3620,
         5807, 10325,  1362,  6542,   595,  9060,  9052,  9459,   351,
         4676,  9354,  7648,  3082,  7694,  8497, 10703,  1610,  9454,
        10236,   244,  7965,  8018,  9392,  6391,  6063,  2878,  1318,
         3169,  8198,  9354,  4131,  3620,  3082,  3352,  9052,  8018,
         7527,  3419,  1907,  8835,   796,   244,  8957,  4325,  8171,
         9454,  7602,  4435,  7648,  3169,  2083,  9454,  4789,  9620,
         9261,   556,  3524,  8497,  9174,  8299,  5871,  9052,  2888,
         9846,  1610,  1362,  4930,  2150,  1362,  8018,  3867,   341,
         7694,  8497,  6063,  3620,   244,  5807,  6089,  3169,  6350,
         1174,  7694,   949,  1292,   244,  9052,  9440,  3690,  1362,
         1907,  9011,  4156,  6081,   145,  1174,  7694,  9986,   949,
         1292,  3169,  1455,  6372,  9760,  5013,  3169,  1455,  5942,
         4365,  1362,  1907,   244,  5813,   244,  7994,  3525,  3550,
         7509,  6372,  9760,  7860,  9052,  2888,  7694,  8497,  1610,
         1316,   326,  1174,  3039,  3524,  9703,  3620,  6612,  1455,
          556,  9011,  3169,  1927,  9052,   409,  4059,  9354,   700,
         5503,  3550,  9052,  2083,  1963,   595,  3169,  7715, 10236,
         9442,  1174, 10087,  3169,  5312,  7474,  9052,  3525,  3169,
         5826,  7885,  6944,  7130,  5821,  2878,  7184,   153,  3169,
         8633,  8574,  1283,   606,  7902,  6110,  3082,  6406,  3169,
         8316,  6126,   688, 10236,  9440,  3082, 10584,  2143,  5460,
         5809,  1362,  2878, 10439,  3419,  1907,  4598,  4156, 10239,
         1450,  5514,  5010,  9350,   244,   651]])}, 0)

Таким образом, значение словаря представляет собой двумерный массив, тогда как оно должно быть только 1-мерным. Где я ошибаюсь?

Спасибо!

Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
1
0
92
1

Ответы 1

Непринятие скаляров кажется ограничением tf.string_split. Пожалуйста, сообщите о проблеме на https://github.com/tensorflow/tensorflow/issues

Что касается обходных путей, предложение обернуть список - это хорошо, но вам также нужно сжать его после разделения, чтобы у вас был вектор компонентов, а не двумерный тензор.

import tensorflow as tf
tf.enable_eager_execution()
scalar = tf.constant('ab c de')
print(scalar.shape)  # () scalar
vector = scalar[None]
print(vector.shape)  # (1,) vector
output = tf.sparse.to_dense(tf.string_split(vector), default_value='')
print(output)  # tf.Tensor([[b'ab' b'c' b'de']], shape=(1, 3), dtype=string)
squeezed = tf.squeeze(output, axis=0)
print(squeezed.shape)  # (3,) vector
print(squeezed)  # tf.Tensor([b'ab' b'c' b'de'], shape=(3,), dtype=string)

Другие вопросы по теме