research: Add pth-to-onnx.py

This commit is contained in:
Fam Zheng 2025-03-23 09:08:29 -07:00
parent bcc06bf255
commit 4b2431c1f0

34
research/pth-to-onnx.py Executable file
View File

@ -0,0 +1,34 @@
#!/usr/bin/env python3
import argparse
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("pth", type=str)
parser.add_argument("onnx", type=str)
return parser.parse_args()
def main():
args = parse_args()
# 加载训练好的模型
import torch
import torchvision.models as models
model = models.resnet18(pretrained=False)
model.fc = torch.nn.Linear(model.fc.in_features, 2) # 假设是二分类任务
model.load_state_dict(torch.load(args.pth)) # 加载训练好的权重
model.eval() # 设置为评估模式
# 创建一个随机输入张量(假设输入图像大小为 128x64
dummy_input = torch.randn(1, 3, 128, 64)
# 导出模型为 ONNX 格式
torch.onnx.export(
model, # 模型
dummy_input, # 输入张量
args.onnx, # 导出的 ONNX 文件名
input_names=['input'], # 输入节点名称
output_names=['output'], # 输出节点名称
dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}} # 支持动态批量大小
)
if __name__ == "__main__":
main()