research: Add pth-to-onnx.py
This commit is contained in:
parent
bcc06bf255
commit
4b2431c1f0
34
research/pth-to-onnx.py
Executable file
34
research/pth-to-onnx.py
Executable file
@ -0,0 +1,34 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("pth", type=str)
|
||||
parser.add_argument("onnx", type=str)
|
||||
return parser.parse_args()
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
# 加载训练好的模型
|
||||
import torch
|
||||
import torchvision.models as models
|
||||
model = models.resnet18(pretrained=False)
|
||||
model.fc = torch.nn.Linear(model.fc.in_features, 2) # 假设是二分类任务
|
||||
model.load_state_dict(torch.load(args.pth)) # 加载训练好的权重
|
||||
model.eval() # 设置为评估模式
|
||||
|
||||
# 创建一个随机输入张量(假设输入图像大小为 128x64)
|
||||
dummy_input = torch.randn(1, 3, 128, 64)
|
||||
|
||||
# 导出模型为 ONNX 格式
|
||||
torch.onnx.export(
|
||||
model, # 模型
|
||||
dummy_input, # 输入张量
|
||||
args.onnx, # 导出的 ONNX 文件名
|
||||
input_names=['input'], # 输入节点名称
|
||||
output_names=['output'], # 输出节点名称
|
||||
dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}} # 支持动态批量大小
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
x
Reference in New Issue
Block a user