Тестовый код:
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.linear = nn.Linear(32, 16)
self.relu1 = nn.ReLU(inplace=True)
self.relu2 = nn.ReLU(inplace=True)
self.fc = nn.Linear(32, 2)
def forward(self, x):
x1, x2 = x
x1 = self.linear(x1)
x1 = self.relu1(x1)
x2 = self.linear(x2)
x2 = self.relu2(x2)
out = torch.cat((x1, x2), dim=-1)
out = self.fc(out)
return out
model = Model()
model.eval()
x1 = torch.randn((2, 10, 32))
x2 = torch.randn((2, 10, 32))
x = (x1, x2)
torch.onnx.export(model,
x,
'model.onnx',
input_names=["input"],
output_names=["output"],
dynamic_axes = {'input': {0: 'batch'}, 'output': {0: 'batch'}}
)
print("Done")
Как преобразовать приведенный выше код в onnx? Вход прямой части моей модели - это кортеж, нельзя преобразовать в формат onnx? Спасибо! Вход прямой части моей модели — это кортеж, который не может быть преобразован в формат onnx существующими методами. Можете ли вы сказать мне, как это решить
Глядя на Эта проблема и это другая проблема, параметры распаковываются по умолчанию, поэтому вам нужно указать кортеж в качестве аргумента для torch.onnx.export
:
torch.onnx.export(model,
args=(x,),
f='model.onnx',
input_names=["input"],
output_names=["output"],
dynamic_axes = {'input': {0: 'batch'}, 'output': {0: 'batch'}})
О, я пропустил вас, я отредактировал свой ответ.
Спасибо за ваш ответ. Я понимаю метод, который вы предоставляете, но я хочу знать, как преобразовать модель в формат onnx, когда прямой ввод является кортежем или списком. Или torch.onnx.export сейчас не поддерживает этот тип работы.