Я пытаюсь использовать tf.image.ssim
, чтобы получить сходство между двумя изображениями, однако он возвращает ошибку атрибута. Поскольку я просто напрямую использую код TensorFlow, я не вижу способа отладить эту проблему.
import tensorflow as tf
from sklearn import datasets
import matplotlib.pyplot as plt
iris = datasets.load_iris()
x_train, y= tf.keras.datasets.mnist.load_data(
path='mnist.npz'
)
tf.image.ssim(
x_train[0][0], x_train[0][0], 255
)
MNIST возвращает изображение в градациях серого в 2D, SSIM требует, чтобы изображение было в 3D. Так что просто расширьте размеры возвращенного изображения, которое вы хотите сравнить, и примените к нему SSIM.
import numpy as np
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(
path='mnist.npz'
)
x_train_expanded = np.expand_dims(x_train[0], axis=2)
print(tf.image.ssim(x_train_expanded, x_train_expanded, 255))
Он возвращает следующее:
tf.Tensor(1.0, shape=(), dtype=float32)
Возвращенный тензор содержит значение MS-SSIM для каждого изображения в пакете. Значения находятся в диапазоне [0, 1], и пример возвращает значение 1, указывающее, что оба изображения идентичны.
(60000, 28, 28)
Спасибо, но все равно не запускается. Я попробовал ваш ответ дословно на двух разных компьютерах. Единственное изменение, которое я сделал в начале, это добавление импорта тензорного потока как tf.
Привет, я снова обновил ответ. Это связано с тем, что данные MNIST отображаются в оттенках серого.
Если это ответит на вопрос, пожалуйста, примите мой ответ.
Я ценю вашу помощь, но теперь я получаю следующую ошибку: AttributeError: объект 'numpy.ndarray' не имеет атрибута 'get_shape'
Может быть, вы импортировали библиотеку, которой у меня нет? Вот почему он работает на вашем компьютере, но не на моем.
Я знал, что мне нужно импортировать numpy :) Моя версия tf 2.2.0
Это должно работать, и я обновил ответ выводом.
Не могли бы вы рассказать мне свою версию tensorflow?
2.3.0 все равно должно работать. Вы действительно импортировали набор данных и установили возвращаемые значения так же, как я, и распечатали значение?
Юп, я так и сделал :(
Давайте продолжим обсуждение в чате.
Когда вы печатаете
print(x_train[0].shape)
что вы получаете?