From 4b2431c1f0d2b363706ea8b2494b6d8a0bcfed2e Mon Sep 17 00:00:00 2001 From: Fam Zheng Date: Sun, 23 Mar 2025 09:08:29 -0700 Subject: [PATCH] research: Add pth-to-onnx.py --- research/pth-to-onnx.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100755 research/pth-to-onnx.py diff --git a/research/pth-to-onnx.py b/research/pth-to-onnx.py new file mode 100755 index 0000000..a2b7bf2 --- /dev/null +++ b/research/pth-to-onnx.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 +import argparse + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("pth", type=str) + parser.add_argument("onnx", type=str) + return parser.parse_args() + +def main(): + args = parse_args() + # 加载训练好的模型 + import torch + import torchvision.models as models + model = models.resnet18(pretrained=False) + model.fc = torch.nn.Linear(model.fc.in_features, 2) # 假设是二分类任务 + model.load_state_dict(torch.load(args.pth)) # 加载训练好的权重 + model.eval() # 设置为评估模式 + + # 创建一个随机输入张量(假设输入图像大小为 128x64) + dummy_input = torch.randn(1, 3, 128, 64) + + # 导出模型为 ONNX 格式 + torch.onnx.export( + model, # 模型 + dummy_input, # 输入张量 + args.onnx, # 导出的 ONNX 文件名 + input_names=['input'], # 输入节点名称 + output_names=['output'], # 输出节点名称 + dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}} # 支持动态批量大小 + ) + +if __name__ == "__main__": + main() \ No newline at end of file