Сшивание фрагментов изображения вместе

Привет, у меня есть пакет изображений, и мне нужно разделить его на неперекрывающиеся фрагменты и отправить каждый патч через функцию softmax, а затем восстановить исходные изображения. Я могу сделать патчи следующим образом:

@tf.function
def grid_img(img,patch_size=(256, 256), padding = "VALID"):
    p_height, p_width = patch_size
    batch_size, height, width, n_filters = img.shape
    p = tf.image.extract_patches(images=img,
                       sizes=[1,p_height, p_width, 1],
                       strides=[1,p_height, p_width, 1],
                       rates=[1, 1, 1, 1],
                       padding=padding)
    new_shape = list(p.shape[1:-1])+[p_height, p_width, n_filters]
    p = tf.keras.layers.Reshape(new_shape)(p)
    return p

Но я не могу понять, как восстановить исходное изображение партиями. Простое преобразование в исходную партию не работает. Данные не будут расположены в правильном порядке. Буду признателен за любую помощь. спасибо

Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
1
0
47
2
Перейти к ответу Данный вопрос помечен как решенный

Ответы 2

Ответ принят как подходящий

IIUC, вы должны иметь возможность использовать простоtf.reshape для восстановления исходных изображений из пакетов исправлений:

import tensorflow as tf

samples = 5
images = tf.random.normal((samples, 256, 256, 3))

@tf.function
def grid(images):
  img_shape = tf.shape(images)
  batch_size, height, width, n_filters = img_shape[0], img_shape[1], img_shape[2], img_shape[3]

  patches = tf.image.extract_patches(images=images,
                                      sizes=[1, 32, 32, 1],
                                      strides=[1, 32, 32, 1],
                                      rates=[1, 1, 1, 1],
                                      padding='VALID')
  return tf.reshape(tf.nn.softmax(patches), (batch_size, height, width, n_filters))
  
patches = grid(images)
print(patches.shape)
# (5, 256, 256, 3)

Обновление 1: Если вы хотите восстановить изображения в правильном порядке, вы можете рассчитать градиенты tf.image.extract_patches, как показано в этом коде фрагмент. Вот пример:

import tensorflow as tf
import matplotlib.pyplot as plt
import pathlib

@tf.function
def grid(images):
  img_shape = tf.shape(images)
  patches = tf.image.extract_patches(images=images,
                                      sizes=[1, 64, 64, 1],
                                      strides=[1, 64, 64, 1],
                                      rates=[1, 1, 1, 1],
                                      padding='VALID')
  return patches

@tf.function
def extract_patches_inverse(shape, patches):
    _x = tf.zeros(shape)
    _y = grid(_x)
    grad = tf.gradients(_y, _x)[0]
    return tf.gradients(_y, _x, grad_ys=patches)[0] / grad


dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)
batch_size = 32

train_ds = tf.keras.utils.image_dataset_from_directory(
  data_dir,
  seed=123,
  image_size=(512, 512),
  batch_size = batch_size, 
  shuffle= False)

images, _ = next(iter(train_ds.skip(1).take(2)))
patches = grid(images)

shape = (batch_size, 512, 512, 3)
images_reconstructed = extract_patches_inverse(shape, patches)

plt.figure()
f, axarr = plt.subplots(1,2) 
axarr[0].imshow(images[0]/ 255)
axarr[1].imshow(images_reconstructed[0] / 255)

Привет, спасибо за ваш ответ. Но если вы удалите softmax и измените форму, чтобы восстановить исходные изображения, это не сработает np.alltrue(images == patches) дает False. Они разные. Данные расположены неправильно

javid 04.04.2022 14:47

Ага, понятно. Смотрите обновленный ответ. Также см. этот пост: stackoverflow.com/questions/44047753/…. Это работает для вас? @джавид

AloneTogether 04.04.2022 16:19

Ух ты! Спасибо. Не знал, что градиент можно использовать таким образом. Не могли бы вы добавить некоторые пояснения. Я новичок в этом. что происходит в extract_patches_inverse

javid 04.04.2022 16:48

Вы вычисляете обратное tf.image.extract_patches. Пожалуйста, проверьте пост, который я связал.

AloneTogether 04.04.2022 17:18

Грязная работа вокруг этого, о которой я думал, состоит в том, чтобы отслеживать местоположение ячеек после преобразования. Не такой элегантный, как @alonetogether Ответ, но все же может быть полезно поделиться.

import numpy as np 
import tensorflow as tf

@tf.function
def grid(images, grid_size=(32, 32)):
    grid_height, grid_width = grid_size
    patches = tf.image.extract_patches(images=images,
                                      sizes=[1, grid_height, grid_width, 1],
                                      strides=[1, grid_height, grid_width, 1],
                                      rates=[1, 1, 1, 1],
                                      padding='VALID')
    return patches

batch_size, height, width, n_filters = shape = (5, 256, 256, 1)
indices = tf.range(batch_size * height * width * n_filters)
images = tf.reshape(indices, (batch_size, height, width, n_filters ))

patches = grid(images)
transfered_indices = tf.reshape(patches, shape=[-1])
tracked_indices = tf.argsort(transfered_indices) # Indices after transformation, Save this 


images = tf.random.normal(shape)

patches = grid(images)

flatten_patches = tf.reshape(patches, shape=[-1])

reconstructed = tf.reshape(tf.gather(flatten_patches, tracked_indices), shape)

np.alltrue(reconstructed==images) # True

Ваш ответ может быть улучшен с помощью дополнительной вспомогательной информации. Пожалуйста, редактировать добавьте дополнительную информацию, например цитаты или документацию, чтобы другие могли подтвердить правильность вашего ответа. Дополнительную информацию о том, как писать хорошие ответы, можно найти в справочном центре.

Community 04.04.2022 19:38

Другие вопросы по теме