Определить плоский слой в нейронной сети с помощью pytorch

Я пытаюсь определить плоский слой перед запуском полносвязного слоя. Поскольку мои входные данные представляют собой тензор формы (512, 2, 2), поэтому я хочу сгладить этот тензор перед слоями FC.

Раньше я получал эту ошибку:

empty(): argument 'size' must be tuple of ints, but found element of type Flatten at pos 2
import torch.nn as nn
class Network(nn.Module):
    def __init__(self):
        super(Network,self).__init__()
        self.flatten=nn.Flatten()
        self.fc1=nn.Linear(self.flatten,512)
        self.fc2=nn.Linear(512,256)
        self.fc3=nn.Linear(256,3)
 
        
    def forward(self,x):
        x=self.flatten(x) # Flatten layer
        x=torch.ReLU(self.fc1(x))  
        x=torch.ReLU(self.fc2(x))
        x=torch.softmax(self.fc3(x))
        return x

Не могли бы вы предоставить полную трассировку стека ошибок?

Ivan 06.05.2022 16:57

Пожалуйста, уточните вашу конкретную проблему или предоставьте дополнительную информацию, чтобы выделить именно то, что вам нужно. Как сейчас написано, трудно точно сказать, о чем вы спрашиваете.

Community 06.05.2022 17:17
Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
0
2
26
1
Перейти к ответу Данный вопрос помечен как решенный

Ответы 1

Ответ принят как подходящий

Эта строка неверна:

        self.fc1 = nn.Linear(self.flatten, 512)

первый аргумент in_features вместо nn.Linear должен быть int, а не nn.Module

в вашем случае вы определили атрибут flatten как модуль nn.Flatten:

        self.flatten = nn.Flatten()

чтобы решить эту проблему, вы должны передать in_features равно количеству функций после выравнивания:

        self.fc1 = nn.Linear(n_features_after_flatten, 512)

Другие вопросы по теме