Веса внимания вверху изображения

h = 16
fig, ax = plt.subplots(ncols=3, nrows=1, figsize=(15, 5))

for i, q_id in enumerate(sorted_indices[0]):
    logit = itm_logit[:, q_id, :]
    prob = torch.nn.functional.softmax(logit, dim=1)
    name = f'{prob[0, 1]:.3f}_query_id_{q_id}'
    
    # Attention map
    attention_map = avg_cross_att[0, q_id, :-1].view(h, h).detach().cpu().numpy()
    
    # Image
    raw_image_resized = raw_image.resize((596, 596))
    
    ax[0].set_title(name)
    ax[0].imshow(attention_map, cmap='viridis')
    ax[0].axis('off')
    
    ax[1].set_title(caption)
    ax[1].imshow(raw_image_resized)
    ax[1].axis('off')
    

    ax[2].set_title(f'Overlay: {name}')
    ax[2].imshow(raw_image_resized)
    ax[2].imshow(attention_map, cmap='viridis', alpha=0.6)  
    ax[2].axis('off')
    

    ax[0].set_aspect('equal')
    ax[1].set_aspect('equal')
    ax[2].set_aspect('equal')
    
    plt.tight_layout()
    plt.savefig(f"./att_maps/{name}.jpg")
    plt.show()
    break

Я пытаюсь наложить веса внимания поверх изображения (по треугольным осям), чтобы я мог видеть, на какой части веса внимания больше сосредоточено внимание.

Однако код, который я разместил, лишь перекрывает вес внимания поверх изображения.

В чем может быть проблема в этом случае?

Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
2
0
59
1
Перейти к ответу Данный вопрос помечен как решенный

Ответы 1

Ответ принят как подходящий

Основная причина этого — разное разрешение изображения и карты внимания. Таким образом, второй вызов imshow уменьшил отображаемую область до крошечного уголка исходного изображения с наложением карты внимания 16x16.

Чтобы это исправить, карту внимания необходимо масштабировать (например, с помощью np.repeat) до разрешения изображения. Вот пример:

import numpy as np
from matplotlib import pyplot as plt
from matplotlib import image

attention_map = np.random.rand(16, 16)
img = image.imread("merlion.jpg")

plt.figure("uneven shapes")
plt.imshow(img)
plt.imshow(attention_map, cmap='viridis', alpha=0.3)

# naive upscaling via np.repeat in both dimensions
attention_map_upscale = np.repeat(np.repeat(attention_map, img.shape[0] // attention_map.shape[0], axis=0),
                                  img.shape[1] // attention_map.shape[1], axis=1)

plt.figure("even shapes")
plt.imshow(img)
plt.imshow(attention_map_upscale, cmap='viridis', alpha=0.3)

plt.show()

Другие вопросы по теме