#!/usr/bin/env python3 import argparse def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("pth", type=str) parser.add_argument("onnx", type=str) return parser.parse_args() def main(): args = parse_args() # 加载训练好的模型 import torch import torchvision.models as models model = models.resnet18(pretrained=False) model.fc = torch.nn.Linear(model.fc.in_features, 2) # 假设是二分类任务 model.load_state_dict(torch.load(args.pth)) # 加载训练好的权重 model.eval() # 设置为评估模式 # 创建一个随机输入张量(假设输入图像大小为 128x64) dummy_input = torch.randn(1, 3, 128, 64) # 导出模型为 ONNX 格式 torch.onnx.export( model, # 模型 dummy_input, # 输入张量 args.onnx, # 导出的 ONNX 文件名 input_names=['input'], # 输入节点名称 output_names=['output'], # 输出节点名称 dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}} # 支持动态批量大小 ) if __name__ == "__main__": main()