Я использую Keras для построения своей модели, в моей модели есть два входа, тип данных которых — «int32». Затем я использую слой keras Lamba для поиска в матрице встраивания с помощью K.gather (ссылка, индексы). Я вижу, что индексы должны быть тензором int, и я думаю, что мой код соответствует этому, я не знаю, почему насчет ошибки. Мне очень нужна помощь!
input_A = Input(batch_shape=(128,1),name='A_input',dtype='int32')
input_B = Input(batch_shape=(128,1),name='B_input',dtype='int32')
input_A_ = Lambda(lambda x:K.reshape(x,(-1,)))(input_A)
input_B_ = Lambda(lambda x:K.reshape(x, (-1,)))(input_B)
input_A__ = Lambda(lambda x:K.cast(x,dtype='int32'))(input_A_)
input_B__ = Lambda(lambda x:K.cast(x,dtype='int32'))(input_B_)
embedded_text_A = Lambda(lambda x:K.gather(M1,x))(input_A__)
embedded_text_B = Lambda(lambda x:K.gather(M1,x))(input_B__)
По какой-то загадочной причине он будет работать правильно, если поместить K.cast()
внутрь lambda
:
input_A = Input(batch_shape=(128,1), name='A_input', dtype='int32')
input_B = Input(batch_shape=(128,1), name='B_input', dtype='int32')
input_A_ = Lambda(lambda x: K.reshape(x, (-1,)))(input_A)
input_B_ = Lambda(lambda x: K.reshape(x, (-1,)))(input_B)
embedded_text_A = Lambda(lambda x: K.gather(M1, K.cast(x, dtype='int32')))(input_A_)
embedded_text_B = Lambda(lambda x: K.gather(M1, K.cast(x, dtype='int32')))(input_B_)
Следовательно, слой Lambda
делает какое-то странное преобразование dtype внутри.
Я предполагаю, что это какая-то ошибка, и моя гипотеза заключается в том, что неявное преобразование происходит внутри Lambda
__call__
(который унаследован от Layer.__call__
). Я не могу это отследить, но я предполагаю, что ошибка «неявного преобразования» находится где-то в Layer.__call__
, но перед строкой 451, где на самом деле вызывается Lambda.call
.