themblem/research/pth-to-onnx.py
2025-03-25 13:29:57 +08:00

34 lines
1.1 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
import argparse
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("pth", type=str)
parser.add_argument("onnx", type=str)
return parser.parse_args()
def main():
args = parse_args()
# 加载训练好的模型
import torch
import torchvision.models as models
model = models.resnet18(pretrained=False)
model.fc = torch.nn.Linear(model.fc.in_features, 2) # 假设是二分类任务
model.load_state_dict(torch.load(args.pth)) # 加载训练好的权重
model.eval() # 设置为评估模式
# 创建一个随机输入张量(假设输入图像大小为 128x64
dummy_input = torch.randn(1, 3, 128, 64)
# 导出模型为 ONNX 格式
torch.onnx.export(
model, # 模型
dummy_input, # 输入张量
args.onnx, # 导出的 ONNX 文件名
input_names=['input'], # 输入节点名称
output_names=['output'], # 输出节点名称
dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}} # 支持动态批量大小
)
if __name__ == "__main__":
main()