Обучение изображений на основе патчей и объединение их вероятности с изображения

Во-первых, я реализовал простую VGG16 сеть для классификации изображений.

model = keras.applications.vgg16.VGG16(include_top = False,
                weights = None,
                input_shape = (32,32,3),
                pooling = 'max',
                classes = 10)

Чья входная форма 32 x 32. Теперь я пытаюсь реализовать patch-based neural network. Основная идея состоит в том, чтобы из входного изображения извлечь 4 фрагмента изображения, подобные этому изображению,

и, наконец, обучите извлеченное изображение патча (поскольку это входная форма нашей модели), объедините их четыре выходных вероятности и найдите окончательный выходной результат (используя нормализацию и argmax). Так,

Как мне это сделать?

Заранее спасибо за вашу помощь.

Примечание:

Я предполагаю, что с помощью resizing to 32 x 32 это возможно.

Моя простая реализация классификации VGG здесь, в Colab.

Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
0
0
1 098
1
Перейти к ответу Данный вопрос помечен как решенный

Ответы 1

Ответ принят как подходящий

Я использовал набор данных MNIST, чтобы получить каждое изображение в виде 4 патчей с tf.image.extract_patches, которые впоследствии передаются как пакет:

import tensorflow as tf
from tensorflow import keras as K
from tensorflow.keras.layers import Conv2D, Flatten, Dense, MaxPooling2D, Dropout
from tensorflow import nn as nn
from functools import partial
import matplotlib.pyplot as plt

(xtrain, ytrain), (xtest, ytest) = tf.keras.datasets.mnist.load_data()

train = tf.data.Dataset.from_tensor_slices((xtrain, ytrain))
test = tf.data.Dataset.from_tensor_slices((xtest, ytest))

patch_s = 18
stride = xtrain.shape[1] - patch_s

get_patches = lambda x, y: (tf.reshape(
    tf.image.extract_patches(
        images=tf.expand_dims(x[..., None], 0),
        sizes=[1, patch_s, patch_s, 1],
        strides=[1, stride, stride, 1],
        rates=[1, 1, 1, 1],
        padding='VALID'), (4, patch_s, patch_s, 1)), y)

train = train.map(get_patches)
test = test.map(get_patches)

fig = plt.figure()
plt.subplots_adjust(wspace=.1, hspace=.2)
images, labels = next(iter(train))
for index, image in enumerate(images):
    ax = plt.subplot(2, 2, index + 1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(image)
plt.show()

Затем в тренировочном цикле я получаю потери для каждого из этих 4 выходов:

def compute_loss(model, x, y, training):
  out = model(x=x, training=training)
  repeated_y = tf.repeat(tf.expand_dims(y, 0), repeats=4, axis=0)
  loss = loss_object(y_true=repeated_y, y_pred=out, from_logits=True)
  loss = tf.reduce_mean(loss, axis=0)
  return loss

Затем я уменьшаю среднее значение оси 0, чтобы объединить все вероятности. Вот полный рабочий код:

import tensorflow as tf
from tensorflow import keras as K
from tensorflow.keras.layers import Conv2D, Flatten, Dense, MaxPooling2D, Dropout
from tensorflow import nn as nn
from functools import partial
import matplotlib.pyplot as plt

(xtrain, ytrain), (xtest, ytest) = tf.keras.datasets.mnist.load_data()

train = tf.data.Dataset.from_tensor_slices((xtrain, ytrain))
test = tf.data.Dataset.from_tensor_slices((xtest, ytest))

patch_s = 18
stride = xtrain.shape[1] - patch_s

get_patches = lambda x, y: (tf.reshape(
    tf.image.extract_patches(
        images=tf.expand_dims(x[..., None], 0),
        sizes=[1, patch_s, patch_s, 1],
        strides=[1, stride, stride, 1],
        rates=[1, 1, 1, 1],
        padding='VALID'), (4, patch_s, patch_s, 1)), y)

train = train.map(get_patches)
test = test.map(get_patches)

fig = plt.figure()
plt.subplots_adjust(wspace=.1, hspace=.2)
images, labels = next(iter(train))
for index, image in enumerate(images):
    ax = plt.subplot(2, 2, index + 1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(image)
plt.show()

def prepare(inputs, targets):
    inputs = tf.divide(x=inputs, y=255)
    targets = tf.one_hot(indices=targets, depth=10)
    return inputs, targets

train = train.take(10_000).map(prepare)
test = test.take(10_00).map(prepare)

class MyCNN(K.Model):
    def __init__(self):
        super(MyCNN, self).__init__()
        Conv = partial(Conv2D, kernel_size=(3, 3), activation=nn.relu)
        MaxPool = partial(MaxPooling2D, pool_size=(2, 2))

        self.conv1 = Conv(filters=16)
        self.maxp1 = MaxPool()
        self.conv2 = Conv(filters=32)
        self.maxp2 = MaxPool()
        self.conv3 = Conv(filters=64)
        self.maxp3 = MaxPool()
        self.flatt = Flatten()
        self.dens1 = Dense(64, activation=nn.relu)
        self.drop1 = Dropout(.5)
        self.dens2 = Dense(10, activation=nn.softmax)

    def call(self, inputs, training=None, **kwargs):
        x = self.conv1(inputs)
        x = self.maxp1(x)
        x = self.conv2(x)
        x = self.maxp2(x)
        x = self.conv3(x)
        x = self.maxp3(x)
        x = self.flatt(x)
        x = self.dens1(x)
        x = self.drop1(x)
        x = self.dens2(x)
        return x

model = MyCNN()

loss_object = tf.losses.categorical_crossentropy

def compute_loss(model, x, y, training):
  out = model(inputs=x, training=training)
  repeated_y = tf.repeat(tf.expand_dims(y, 0), repeats=4, axis=0)
  loss = loss_object(y_true=repeated_y, y_pred=out, from_logits=True)
  loss = tf.reduce_mean(loss, axis=0)
  return loss

def get_grad(model, x, y):
    with tf.GradientTape() as tape:
        loss = compute_loss(model, x, y, training=False)
    return loss, tape.gradient(loss, model.trainable_variables)

optimizer = tf.optimizers.Adam()

verbose = "Epoch {:2d}" \
          " Loss: {:.3f} Acc: {:.3%} TLoss: {:.3f} TAcc: {:.3%}"

for epoch in range(1, 10 + 1):
    train_loss = tf.metrics.Mean()
    train_acc = tf.metrics.CategoricalAccuracy()
    test_loss = tf.metrics.Mean()
    test_acc = tf.metrics.CategoricalAccuracy()

    for x, y in train:
        loss_value, grads = get_grad(model, x, y)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        train_loss.update_state(loss_value)
        train_acc.update_state(y, model(x, training=True))

    for x, y in test:
        loss_value, _ = get_grad(model, x, y)
        test_loss.update_state(loss_value)
        test_acc.update_state(y, model(x, training=False))

    print(verbose.format(epoch,
                         train_loss.result(),
                         train_acc.result(),
                         test_loss.result(),
                         test_acc.result()))

Спойлер: с такими маленькими патчами это не очень хорошо. Делайте патчи больше 18/28 для лучшей производительности.

Другие вопросы по теме

Как использовать Python mitm для захвата запросов и воспроизведения по запросу через флягу
Модули не могут загружаться в новый проект Django
Подпроцесс Python.Popen + ffmpeg прерывает ввод терминала
Сетка осей не отображается при явном использовании стиля, содержащего сетку, с помощью matplotlib.pyplot.style.use('seaborn-whitegrid')
Заполните новый столбец в фрейме данных pandas двоичным значением, если значение другого столбца находится в списке или наборе
Проверка дублирующегося адреса электронной почты, вызывающего ошибку неверного запроса
Как получить текст класса css с помощью Selenium
Как я могу обновить окно с изображениями в Python Kivy?
Есть ли способ извлечь часть фрейма данных с помощью pandas (или другой библиотеки) в Python?
Фильтр данных Dataframe на основе совпадающих значений в столбце и временной метки минимальных/максимальных значений тех значений, которые совпали