Визуализация встраивания узлов последнего слоя модели в факел геометрического

Я делаю свой первый проект графовой сверточной нейронной сети с torch_geometric. Я хочу визуализировать вложения узлов последнего слоя моей модели и не знаю, как мне это получить.

Я обучил свою модель набору данных CiteSeer. Вы можете получить полный набор данных так же просто, как это:

from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
dataset = Planetoid(root = "data/Planetoid", name='CiteSeer', transform=NormalizeFeatures())

Моя модель представляет собой простую двухслойную модель:

class GraphClassifier(torch.nn.Module):
    def __init__(self, dataset, hidden_dim):
        super(GraphClassifier, self).__init__()
        self.conv1 = GCNConv(dataset.num_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, dataset.num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        return F.log_softmax(x, dim=1) 

Если вы распечатаете мою модель, вы получите это:

model = GraphClassifier(dataset, 64)
print(model)

>>>
GraphClassifier(
  (conv1): GCNConv(3703, 64)
  (conv2): GCNConv(64, 6)
)

Моя модель успешно обучена. Я только хочу визуализировать его вложения узлов последнего уровня. Чтобы визуализировать, что у меня есть эта функция для использования:

%matplotlib inline
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import torch
# emb: (nNodes, hidden_dim)
# node_type: (nNodes,). Entries are torch.int64 ranged from 0 to num_class - 1

def visualize(emb: torch.tensor, node_type: torch.tensor):
  z = TSNE(n_components=2).fit_transform(emb.detach().cpu().numpy())
  plt.figure(figsize=(10,10))
  plt.scatter(z[:, 0], z[:, 1], s=70, c=node_type, cmap = "Set2")
  plt.show()

Я не знаю, как мне извлечь emb и node_type из моей модели, чтобы передать их функции visualize. emb — последний слой вложений узлов модели. Как я могу получить их от моей модели?

Мутабельность и переработка объектов в Python
Мутабельность и переработка объектов в Python
Объекты являются основной конструкцией любого языка ООП, и каждый язык определяет свой собственный синтаксис для их создания, обновления и...
Другой маршрут в Flask Python
Другой маршрут в Flask Python
Flask - это фреймворк, который поддерживает веб-приложения. В этой статье я покажу, как мы можем использовать @app .route в flask, чтобы иметь другую...
14 Задание: Типы данных и структуры данных Python для DevOps
14 Задание: Типы данных и структуры данных Python для DevOps
Проверить тип данных используемой переменной, мы можем просто написать: your_variable=100
Python PyPDF2 - запись метаданных PDF
Python PyPDF2 - запись метаданных PDF
Python скрипт, который будет записывать метаданные в PDF файл, для этого мы будем использовать PDF ридер из библиотеки PyPDF2 . PyPDF2 - это...
Переменные, типы данных и операторы в Python
Переменные, типы данных и операторы в Python
В Python переменные используются как место для хранения значений. Пример переменной формы:
Почему Python - идеальный выбор для проекта AI и ML
Почему Python - идеальный выбор для проекта AI и ML
Блог, которым поделился Harikrishna Kundariya в нашем сообществе Developer Nation Community.
0
0
69
1
Перейти к ответу Данный вопрос помечен как решенный

Ответы 1

Ответ принят как подходящий

Это решается путем изменения модели на это:

class GraphClassifier(torch.nn.Module):
    def __init__(self, dataset, hidden_dim):
        super(GraphClassifier, self).__init__()
        self.conv1 = GCNConv(dataset.num_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, dataset.num_classes)

    def forward(self, data, do_visualize=False):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        if do_visualize:                          # NEW LINE
            visualize(x, data.y)                  # NEW LINE
        return F.log_softmax(x, dim=1)

Теперь, если вы вызовете прямую функцию с помощью do_visualize=Ture, она будет визуализирована. как это:

model = GraphClassifier(dataset, hidden_dim)
model.to(device)
model(dataset[0].to(device), do_visualize=True)

Другие вопросы по теме