Converting a PyTorch ONNX model to a TensorRT engine - Jetson Orin Nano

I am trying to convert the ViT-B/32 Vision Transformer model from the UNICOM repository on a Jetson Orin Nano. The Vision Transformer class and model source code are here.

I use the following code to export the model to ONNX:

import torch
import onnx
import onnxruntime

from unicom.vision_transformer import build_model

if __name__ == '__main__':
    model_name = "ViT-B/32"
    model_name_fp16 = "FP16-ViT-B-32"
    onnx_model_path = f"{model_name_fp16}.onnx"

    model = build_model(model_name)
    model.eval()
    model = model.to('cuda')
    torch_input = torch.randn(1, 3, 224, 224).to('cuda')

    onnx_program = torch.onnx.dynamo_export(model, torch_input)
    onnx_program.save(onnx_model_path)

    # Load the exported file and validate it with the ONNX checker
    onnx_model = onnx.load(onnx_model_path)
    onnx.checker.check_model(onnx_model)

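For completeness, I also sanity-check the exported file by comparing the ONNX Runtime output against the PyTorch output (a minimal sketch that continues the script above; the input name is read from the session rather than assumed):

import numpy as np

ort_session = onnxruntime.InferenceSession(
    onnx_model_path, providers=["CPUExecutionProvider"])
input_name = ort_session.get_inputs()[0].name
# Run the same random input through both backends and compare outputs.
ort_out = ort_session.run(None, {input_name: torch_input.cpu().numpy()})[0]
with torch.no_grad():
    torch_out = model(torch_input).cpu().numpy()
print("max abs diff:", np.abs(torch_out - ort_out).max())
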
Then I use the following command line to convert the ONNX model into a TensorRT engine:

/usr/src/tensorrt/bin/trtexec --onnx=FP16-ViT-B-32.onnx --saveEngine=FP16-ViT-B-32.trt --workspace=1024 --fp16

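As the log below points out, --workspace is deprecated in favour of --memPoolSize; if I read the deprecation message correctly, the equivalent invocation with the newer flag would be:

/usr/src/tensorrt/bin/trtexec --onnx=FP16-ViT-B-32.onnx --saveEngine=FP16-ViT-B-32.trt --memPoolSize=workspace:1024 --fp16
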
This produces the following error:

--workspace flag has been deprecated by --memPoolSize flag.
=== Model Options ===
Format: ONNX
Model: /home/jetson/HPS/Models/FeatureExtractor/UNICOM/ONNX/FP16-ViT-B-32.onnx
Output:
=== Build Options ===
Max batch: explicit batch
Memory Pools: workspace: 1024 MiB, dlaSRAM: default, dlaLocalDRAM: default, dlaGlobalDRAM: default
minTiming: 1
avgTiming: 8
Precision: FP32+FP16
LayerPrecisions:
Layer Device Types:
Calibration:
Refit: Disabled
Version Compatible: Disabled
ONNX Native InstanceNorm: Disabled
TensorRT runtime: full
Lean DLL Path:
Tempfile Controls: { in_memory: allow, temporary: allow }
Exclude Lean Runtime: Disabled
Sparsity: Disabled
Safe mode: Disabled
Build DLA standalone loadable: Disabled
Allow GPU fallback for DLA: Disabled
DirectIO mode: Disabled
Restricted mode: Disabled
Skip inference: Disabled
Save engine: /home/jetson/HPS/Models/FeatureExtractor/UNICOM/ONNX/FP16-ViT-B-32.trt
Load engine:
Profiling verbosity: 0
Tactic sources: Using default tactic sources
timingCacheMode: local
timingCacheFile:
Heuristic: Disabled
Preview Features: Use default preview flags.
MaxAuxStreams: -1
BuilderOptimizationLevel: -1
Input(s)s format: fp32:CHW
Output(s)s format: fp32:CHW
Input build shapes: model
Input calibration shapes: model
=== System Options ===
Device: 0
DLACore:
Plugins:
setPluginsToSerialize:
dynamicPlugins:
ignoreParsedPluginLibs: 0
=== Inference Options ===
Batch: Explicit
Input inference shapes: model
Iterations: 10
Duration: 3s (+ 200ms warm up)
Sleep time: 0ms
Idle time: 0ms
Inference Streams: 1
ExposeDMA: Disabled
Data transfers: Enabled
Spin-wait: Disabled
Multithreading: Disabled
CUDA Graph: Disabled
Separate profiling: Disabled
Time Deserialize: Disabled
Time Refit: Disabled
NVTX verbosity: 0
Persistent Cache Ratio: 0
Inputs:
=== Reporting Options ===
Verbose: Enabled
Averages: 10 inferences
Percentiles: 90,95,99
Dump refittable layers:Disabled
Dump output: Disabled
Profile: Disabled
Export timing to JSON file:
Export output to JSON file:
Export profile to JSON file:
=== Device Information ===
Selected Device: Orin
Compute Capability: 8.7
SMs: 8
Device Global Memory: 7620 MiB
Shared Memory per SM: 164 KiB
Memory Bus Width: 128 bits (ECC disabled)
Application Compute Clock Rate: 0.624 GHz
Application Memory Clock Rate: 0.624 GHz
Note: The application clock rates do not reflect the actual clock rates that the GPU is currently running at.
TensorRT version: 8.6.2
Loading standard plugins
Registered plugin - ::BatchedNMSDynamic_TRT version 1
Registered plugin - ::BatchedNMS_TRT version 1
Registered plugin - ::BatchTilePlugin_TRT version 1
Registered plugin - ::Clip_TRT version 1
Registered plugin - ::CoordConvAC version 1
Registered plugin - ::CropAndResizeDynamic version 1
Registered plugin - ::CropAndResize version 1
Registered plugin - ::DecodeBbox3DPlugin version 1
Registered plugin - ::DetectionLayer_TRT version 1
Registered plugin - ::EfficientNMS_Explicit_TF_TRT version 1
Registered plugin - ::EfficientNMS_Implicit_TF_TRT version 1
Registered plugin - ::EfficientNMS_ONNX_TRT version 1
Registered plugin - ::EfficientNMS_TRT version 1
Registered plugin - ::FlattenConcat_TRT version 1
Registered plugin - ::GenerateDetection_TRT version 1
Registered plugin - ::GridAnchor_TRT version 1
Registered plugin - ::GridAnchorRect_TRT version 1
Registered plugin - ::InstanceNormalization_TRT version 1
Registered plugin - ::InstanceNormalization_TRT version 2
Registered plugin - ::LReLU_TRT version 1
Registered plugin - ::ModulatedDeformConv2d version 1
Registered plugin - ::MultilevelCropAndResize_TRT version 1
Registered plugin - ::MultilevelProposeROI_TRT version 1
Registered plugin - ::MultiscaleDeformableAttnPlugin_TRT version 1
Registered plugin - ::NMSDynamic_TRT version 1
Registered plugin - ::NMS_TRT version 1
Registered plugin - ::Normalize_TRT version 1
Registered plugin - ::PillarScatterPlugin version 1
Registered plugin - ::PriorBox_TRT version 1
Registered plugin - ::ProposalDynamic version 1
Registered plugin - ::ProposalLayer_TRT version 1
Registered plugin - ::Proposal version 1
Registered plugin - ::PyramidROIAlign_TRT version 1
Registered plugin - ::Region_TRT version 1
Registered plugin - ::Reorg_TRT version 1
Registered plugin - ::ResizeNearest_TRT version 1
Registered plugin - ::ROIAlign_TRT version 1
Registered plugin - ::RPROI_TRT version 1
Registered plugin - ::ScatterND version 1
Registered plugin - ::SpecialSlice_TRT version 1
Registered plugin - ::Split version 1
Registered plugin - ::VoxelGeneratorPlugin version 1
[MemUsageChange] Init CUDA: CPU +2, GPU +0, now: CPU 33, GPU 5167 (MiB)
Trying to load shared library libnvinfer_builder_resource.so.8.6.2
Loaded shared library libnvinfer_builder_resource.so.8.6.2
[MemUsageChange] Init builder kernel library: CPU +1154, GPU +995, now: CPU 1223, GPU 6203 (MiB)
CUDA lazy loading is enabled.
Start parsing network model.
----------------------------------------------------------------
Input filename:   /home/jetson/HPS/Models/FeatureExtractor/UNICOM/ONNX/FP16-ViT-B-32.onnx
ONNX IR version:  0.0.8
Opset version:    1
Producer name:    pytorch
Producer version: 2.3.0
Domain:
Model version:    0
Doc string:
----------------------------------------------------------------
Plugin already registered - ::BatchedNMSDynamic_TRT version 1
Plugin already registered - ::BatchedNMS_TRT version 1
Plugin already registered - ::BatchTilePlugin_TRT version 1
Plugin already registered - ::Clip_TRT version 1
Plugin already registered - ::CoordConvAC version 1
Plugin already registered - ::CropAndResizeDynamic version 1
Plugin already registered - ::CropAndResize version 1
Plugin already registered - ::DecodeBbox3DPlugin version 1
Plugin already registered - ::DetectionLayer_TRT version 1
Plugin already registered - ::EfficientNMS_Explicit_TF_TRT version 1
Plugin already registered - ::EfficientNMS_Implicit_TF_TRT version 1
Plugin already registered - ::EfficientNMS_ONNX_TRT version 1
Plugin already registered - ::EfficientNMS_TRT version 1
Plugin already registered - ::FlattenConcat_TRT version 1
Plugin already registered - ::GenerateDetection_TRT version 1
Plugin already registered - ::GridAnchor_TRT version 1
Plugin already registered - ::GridAnchorRect_TRT version 1
Plugin already registered - ::InstanceNormalization_TRT version 1
Plugin already registered - ::InstanceNormalization_TRT version 2
Plugin already registered - ::LReLU_TRT version 1
Plugin already registered - ::ModulatedDeformConv2d version 1
Plugin already registered - ::MultilevelCropAndResize_TRT version 1
Plugin already registered - ::MultilevelProposeROI_TRT version 1
Plugin already registered - ::MultiscaleDeformableAttnPlugin_TRT version 1
Plugin already registered - ::NMSDynamic_TRT version 1
Plugin already registered - ::NMS_TRT version 1
Plugin already registered - ::Normalize_TRT version 1
Plugin already registered - ::PillarScatterPlugin version 1
Plugin already registered - ::PriorBox_TRT version 1
Plugin already registered - ::ProposalDynamic version 1
Plugin already registered - ::ProposalLayer_TRT version 1
Plugin already registered - ::Proposal version 1
Plugin already registered - ::PyramidROIAlign_TRT version 1
Plugin already registered - ::Region_TRT version 1
Plugin already registered - ::Reorg_TRT version 1
Plugin already registered - ::ResizeNearest_TRT version 1
Plugin already registered - ::ROIAlign_TRT version 1
Plugin already registered - ::RPROI_TRT version 1
Plugin already registered - ::ScatterND version 1
Plugin already registered - ::SpecialSlice_TRT version 1
Plugin already registered - ::Split version 1
Plugin already registered - ::VoxelGeneratorPlugin version 1
Adding network input: l_x_ with dtype: float32, dimensions: (1, 3, 224, 224)
Registering tensor: l_x_ for ONNX tensor: l_x_
Importing : patch_embed.proj.weight
Importing : patch_embed.proj.bias
Importing : pos_embed
Importing : blocks.0.norm1.weight
Importing : blocks.0.norm1.bias
Importing : blocks.0.attn.qkv.weight
Importing : blocks.0.attn.proj.weight
Importing : blocks.0.attn.proj.bias
Importing : blocks.0.norm2.weight
Importing : blocks.0.norm2.bias
Importing : blocks.0.mlp.fc1.weight
Importing : blocks.0.mlp.fc1.bias
Importing : blocks.0.mlp.fc2.weight
Importing : blocks.0.mlp.fc2.bias
Importing : blocks.1.norm1.weight
Importing : blocks.1.norm1.bias
Importing : blocks.1.attn.qkv.weight
Importing : blocks.1.attn.proj.weight
Importing : blocks.1.attn.proj.bias
Importing : blocks.1.norm2.weight
Importing : blocks.1.norm2.bias
Importing : blocks.1.mlp.fc1.weight
Importing : blocks.1.mlp.fc1.bias
Importing : blocks.1.mlp.fc2.weight
Importing : blocks.1.mlp.fc2.bias
Importing : blocks.2.norm1.weight
Importing : blocks.2.norm1.bias
Importing : blocks.2.attn.qkv.weight
Importing : blocks.2.attn.proj.weight
Importing : blocks.2.attn.proj.bias
Importing : blocks.2.norm2.weight
Importing : blocks.2.norm2.bias
Importing : blocks.2.mlp.fc1.weight
Importing : blocks.2.mlp.fc1.bias
Importing : blocks.2.mlp.fc2.weight
Importing : blocks.2.mlp.fc2.bias
Importing : blocks.3.norm1.weight
Importing : blocks.3.norm1.bias
Importing : blocks.3.attn.qkv.weight
Importing : blocks.3.attn.proj.weight
Importing : blocks.3.attn.proj.bias
Importing : blocks.3.norm2.weight
Importing : blocks.3.norm2.bias
Importing : blocks.3.mlp.fc1.weight
Importing : blocks.3.mlp.fc1.bias
Importing : blocks.3.mlp.fc2.weight
Importing : blocks.3.mlp.fc2.bias
Importing : blocks.4.norm1.weight
Importing : blocks.4.norm1.bias
Importing : blocks.4.attn.qkv.weight
Importing : blocks.4.attn.proj.weight
Importing : blocks.4.attn.proj.bias
Importing : blocks.4.norm2.weight
Importing : blocks.4.norm2.bias
Importing : blocks.4.mlp.fc1.weight
Importing : blocks.4.mlp.fc1.bias
Importing : blocks.4.mlp.fc2.weight
Importing : blocks.4.mlp.fc2.bias
Importing : blocks.5.norm1.weight
Importing : blocks.5.norm1.bias
Importing : blocks.5.attn.qkv.weight
Importing : blocks.5.attn.proj.weight
Importing : blocks.5.attn.proj.bias
Importing : blocks.5.norm2.weight
Importing : blocks.5.norm2.bias
Importing : blocks.5.mlp.fc1.weight
Importing : blocks.5.mlp.fc1.bias
Importing : blocks.5.mlp.fc2.weight
Importing : blocks.5.mlp.fc2.bias
Importing : blocks.6.norm1.weight
Importing : blocks.6.norm1.bias
Importing : blocks.6.attn.qkv.weight
Importing : blocks.6.attn.proj.weight
Importing : blocks.6.attn.proj.bias
Importing : blocks.6.norm2.weight
Importing : blocks.6.norm2.bias
Importing : blocks.6.mlp.fc1.weight
Importing : blocks.6.mlp.fc1.bias
Importing : blocks.6.mlp.fc2.weight
Importing : blocks.6.mlp.fc2.bias
Importing : blocks.7.norm1.weight
Importing : blocks.7.norm1.bias
Importing : blocks.7.attn.qkv.weight
Importing : blocks.7.attn.proj.weight
Importing : blocks.7.attn.proj.bias
Importing : blocks.7.norm2.weight
Importing : blocks.7.norm2.bias
Importing : blocks.7.mlp.fc1.weight
Importing : blocks.7.mlp.fc1.bias
Importing : blocks.7.mlp.fc2.weight
Importing : blocks.7.mlp.fc2.bias
Importing : blocks.8.norm1.weight
Importing : blocks.8.norm1.bias
Importing : blocks.8.attn.qkv.weight
Importing : blocks.8.attn.proj.weight
Importing : blocks.8.attn.proj.bias
Importing : blocks.8.norm2.weight
Importing : blocks.8.norm2.bias
Importing : blocks.8.mlp.fc1.weight
Importing : blocks.8.mlp.fc1.bias
Importing : blocks.8.mlp.fc2.weight
Importing : blocks.8.mlp.fc2.bias
Importing : blocks.9.norm1.weight
Importing : blocks.9.norm1.bias
Importing : blocks.9.attn.qkv.weight
Importing : blocks.9.attn.proj.weight
Importing : blocks.9.attn.proj.bias
Importing : blocks.9.norm2.weight
Importing : blocks.9.norm2.bias
Importing : blocks.9.mlp.fc1.weight
Importing : blocks.9.mlp.fc1.bias
Importing : blocks.9.mlp.fc2.weight
Importing : blocks.9.mlp.fc2.bias
Importing : blocks.10.norm1.weight
Importing : blocks.10.norm1.bias
Importing : blocks.10.attn.qkv.weight
Importing : blocks.10.attn.proj.weight
Importing : blocks.10.attn.proj.bias
Importing : blocks.10.norm2.weight
Importing : blocks.10.norm2.bias
Importing : blocks.10.mlp.fc1.weight
Importing : blocks.10.mlp.fc1.bias
Importing : blocks.10.mlp.fc2.weight
Importing : blocks.10.mlp.fc2.bias
Importing : blocks.11.norm1.weight
Importing : blocks.11.norm1.bias
Importing : blocks.11.attn.qkv.weight
Importing : blocks.11.attn.proj.weight
Importing : blocks.11.attn.proj.bias
Importing : blocks.11.norm2.weight
Importing : blocks.11.norm2.bias
Importing : blocks.11.mlp.fc1.weight
Importing : blocks.11.mlp.fc1.bias
Importing : blocks.11.mlp.fc2.weight
Importing : blocks.11.mlp.fc2.bias
Importing : norm.weight
Importing : norm.bias
Importing : feature.0.weight
Importing : feature.1.weight
Importing : feature.1.bias
Importing : feature.1.running_mean
Importing : feature.1.running_var
Importing : feature.2.weight
Importing : feature.3.weight
Importing : feature.3.bias
Importing : feature.3.running_mean
Importing : feature.3.running_var
Parsing node: unicom_vision_transformer_PatchEmbedding_patch_embed_1_1 [unicom_vision_transformer_PatchEmbedding_patch_embed_1]
Searching for input: l_x_
Searching for input: patch_embed.proj.weight
Searching for input: patch_embed.proj.bias
unicom_vision_transformer_PatchEmbedding_patch_embed_1_1 [unicom_vision_transformer_PatchEmbedding_patch_embed_1] inputs: [l_x_ -> (1, 3, 224, 224)[FLOAT]], [patch_embed.proj.weight -> (768, 3, 32, 32)[FLOAT]], [patch_embed.proj.bias -> (768)[FLOAT]],
No importer registered for op: unicom_vision_transformer_PatchEmbedding_patch_embed_1. Attempting to import as plugin.
Searching for plugin: unicom_vision_transformer_PatchEmbedding_patch_embed_1, plugin_version: 1, plugin_namespace:
Local registry did not find unicom_vision_transformer_PatchEmbedding_patch_embed_1 creator. Will try parent registry if enabled.
Global registry did not find unicom_vision_transformer_PatchEmbedding_patch_embed_1 creator. Will try parent registry if enabled.
3: getPluginCreator could not find plugin: unicom_vision_transformer_PatchEmbedding_patch_embed_1 version: 1
ModelImporter.cpp:768: While parsing node number 0 [unicom_vision_transformer_PatchEmbedding_patch_embed_1 -> "patch_embed_1"]:
ModelImporter.cpp:769: --- Begin node ---
ModelImporter.cpp:770: input: "l_x_"
input: "patch_embed.proj.weight"
input: "patch_embed.proj.bias"
output: "patch_embed_1"
name: "unicom_vision_transformer_PatchEmbedding_patch_embed_1_1"
op_type: "unicom_vision_transformer_PatchEmbedding_patch_embed_1"
doc_string: ""
domain: "pkg.unicom"

input: "patch_embed.proj.weight"
input: "patch_embed.proj.bias"
output: "patch_embed_1"
name: "unicom_vision_transformer_PatchEmbedding_patch_embed_1_1"
op_type: "unicom_vision_transformer_PatchEmbedding_patch_embed_1"
doc_string: ""
domain: "pkg.unicom"

[E] ModelImporter.cpp:771: --- End node ---
[E] ModelImporter.cpp:773: ERROR: builtin_op_importers.cpp:5403 In function importFallbackPluginImporter:
[8] Assertion failed: creator && "Plugin not found, are the plugin name, version, and namespace correct?"
[E] Failed to parse onnx file
[I] Finished parsing network model. Parse time: 13.1481
[E] Parsing model failed
[E] Failed to create engine from model or file.
[E] Engine set up failed

The problem seems to be caused by the PatchEmbedding class here, and the model does not appear to use any unusual methods or layers that TensorRT cannot convert. Here is the source code of the class:

class PatchEmbedding(nn.Module):
    def __init__(self, input_size=224, patch_size=32, in_channels: int = 3, dim: int = 768):
        super().__init__()
        if isinstance(input_size, int):
            input_size = (input_size, input_size)
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        H = input_size[0] // patch_size[0]
        W = input_size[1] // patch_size[1]
        self.num_patches = H * W
        self.proj = nn.Conv2d(
            in_channels, dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x

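To isolate whether PatchEmbedding itself is the problem, one thing I can try is exporting just this module on its own (a sketch, assuming PatchEmbedding is importable from unicom.vision_transformer in the same way as build_model):

import torch
from unicom.vision_transformer import PatchEmbedding

# Hypothetical isolation test: export only the patch-embedding module and
# feed the resulting ONNX file to trtexec separately.
patch_embed = PatchEmbedding(input_size=224, patch_size=32, in_channels=3, dim=768).eval()
dummy = torch.randn(1, 3, 224, 224)
torch.onnx.dynamo_export(patch_embed, dummy).save("patch_embed_only.onnx")
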
What do I need to do so that the model can be converted to TensorRT?

## Environment

**TensorRT Version**:  tensorrt_version_8_6_2_3

**GPU Type**: Jetson Orin Nano

**Nvidia Driver Version**:

**CUDA Version**: 12.2

**CUDNN Version**:  8.9.4.25-1+cuda12.2

**Operating System + Version**: Jetpack 6.0

**Python Version (if applicable)**: 3.10

**PyTorch Version (if applicable)**:  2.3.0

**ONNX Version (if applicable)**:  1.16.1

**onnxruntime-gpu Version (if applicable)**:  1.17.0

**onnxscript Version (if applicable)**:  0.1.0.dev20240721

UPDATE: Running the code with torch.onnx.export instead of torch.onnx.dynamo_export gives the following error:

/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension:
      warn(f"Failed to load image Python extension: {e}")
    /home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/_dynamo/external_utils.py:36: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.
      return fn(*args, **kwargs)
    Traceback (most recent call last):
      File "/home/jetson/HPS/Scripts_Utilities/ONNX/HPS_ExportModelToONNX.py", line 31, in <module>
        torch.onnx.export(model, torch_input,onnx_model_path)
      File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/onnx/utils.py", line 516, in export
        _export(
      File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/onnx/utils.py", line 1612, in _export
        graph, params_dict, torch_out = _model_to_graph(
      File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/onnx/utils.py", line 1134, in _model_to_graph
        graph, params, torch_out, module = _create_jit_graph(model, args)
      File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/onnx/utils.py", line 1010, in _create_jit_graph
        graph, torch_out = _trace_and_get_graph_from_model(model, args)
      File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/onnx/utils.py", line 914, in _trace_and_get_graph_from_model
        trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
      File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/jit/_trace.py", line 1315, in _get_trace_graph
        outs = ONNXTracedModule(
      File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
        return forward_call(*args, **kwargs)
      File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/jit/_trace.py", line 141, in forward
        graph, out = torch._C._create_graph_by_tracing(
      File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/jit/_trace.py", line 132, in wrapper
        outs.append(self.inner(*trace_inputs))
      File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
        return forward_call(*args, **kwargs)
      File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1522, in _slow_forward
        result = self.forward(*input, **kwargs)
      File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/unicom/vision_transformer.py", line 57, in forward
        x = self.forward_features(x)
      File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/unicom/vision_transformer.py", line 52, in forward_features
        x = func(x)
      File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
        return forward_call(*args, **kwargs)
      File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1522, in _slow_forward
        result = self.forward(*input, **kwargs)
      File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/unicom/vision_transformer.py", line 122, in forward
        return checkpoint(self.forward_impl, x)
      File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
        return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
      File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 403, in _fn
        return fn(*args, **kwargs)
      File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 36, in inner
        return fn(*args, **kwargs)
      File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 481, in checkpoint
        return CheckpointFunction.apply(function, preserve, *args)
      File "/home/jetson/miniconda3/envs/HPS/lib/python3.10/site-packages/torch/autograd/function.py", line 571, in apply
        return super().apply(*args, **kwargs)  # type: ignore[misc]
    RuntimeError: _Map_base::at

torch.onnx.dynamo_export is still in beta - have you tried the older TorchScript-based torch.onnx.export?
simeonovich 25.07.2024 12:01

Can you give me a link to that approach?

Cypher 26.07.2024 20:40
Link to the torch docs - it's basically the same thing, just with a different function call.
simeonovich 26.07.2024 22:18

@simeonovich I've updated the question with new information on using torch.onnx.export instead of torch.onnx.dynamo_export. Thank you.

Cypher 27.07.2024 12:18

1 Answer

Accepted answer

The fact that trtexec trips over PatchEmbedding is definitely strange, given that the whole module is literally just a convolution. I genuinely don't see what's wrong there.

As for the exception you get with the legacy TorchScript exporter (torch.onnx.export): in this case it fails because parts of the transformer are wrapped in torch.utils.checkpoint, which apparently does not support tracing. If we remove the checkpointing by initializing VisionTransformer with using_checkpoint=False, the model exports without any errors.

import torch
from unicom.vision_transformer import VisionTransformer

model_name_fp16 = "FP16-ViT-B-32"
onnx_model_path = f"{model_name_fp16}.onnx"
model = VisionTransformer(
    input_size=224,
    patch_size=32,
    in_channels=3,
    dim=768,
    embedding_size=512,
    depth=12,
    num_heads=12,
    drop_path_rate=0.1,
    using_checkpoint=False, # default value of True breaks torch.onnx.export
)
model.eval()
model = model.to('cuda')
torch_input = torch.randn(1, 3, 224, 224).to('cuda')
# using TorchScript export instead of TorchDynamo
torch.onnx.export(model, torch_input, onnx_model_path)

And the resulting ONNX model builds into a TensorRT engine successfully.

>>> trtexec --onnx=FP16-ViT-B-32.onnx --fp16
...
>>> [I] Engine built in 24.711 sec.
>>> &&&& PASSED TensorRT.trtexec [TensorRT v100200]

The above worked with torch==2.2.0 and TensorRT==10.2.

Thanks. I'm now trying to convert it into an engine file on the Jetson Orin Nano, but it seems the VRAM on the device can't handle the larger ViT-L/14 model. Is there anything that can be done so the model can be converted even on low-VRAM devices? 🤔

Cypher 29.07.2024 13:56

Are you running out of memory even when ViT-L/14 is exported with batch size 1? That's odd - on my system peak VRAM is around 3-4 GB during the build stage. I would try two things. First, upgrading TensorRT to a newer version may help: I've seen significant differences in memory footprint for the same models even between minor releases. Second, try removing the autocast(False) context manager from the unicom Attention class. I don't know exactly how torch.amp is handled during model conversion, but it may prevent TensorRT from using fp16 ops.
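If nothing else helps, you could also try capping the builder workspace pool so engine building uses less device memory - something like the line below (untested on the Orin Nano; the 2048 MiB figure and the ViT-L/14 filename are only placeholders):

trtexec --onnx=FP16-ViT-L-14.onnx --saveEngine=FP16-ViT-L-14.trt --fp16 --memPoolSize=workspace:2048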

simeonovich 29.07.2024 18:05
