research: Add pth-to-onnx.py
This commit is contained in:
parent
bcc06bf255
commit
4b2431c1f0
34
research/pth-to-onnx.py
Executable file
34
research/pth-to-onnx.py
Executable file
@ -0,0 +1,34 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("pth", type=str)
|
||||||
|
parser.add_argument("onnx", type=str)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = parse_args()
|
||||||
|
# 加载训练好的模型
|
||||||
|
import torch
|
||||||
|
import torchvision.models as models
|
||||||
|
model = models.resnet18(pretrained=False)
|
||||||
|
model.fc = torch.nn.Linear(model.fc.in_features, 2) # 假设是二分类任务
|
||||||
|
model.load_state_dict(torch.load(args.pth)) # 加载训练好的权重
|
||||||
|
model.eval() # 设置为评估模式
|
||||||
|
|
||||||
|
# 创建一个随机输入张量(假设输入图像大小为 128x64)
|
||||||
|
dummy_input = torch.randn(1, 3, 128, 64)
|
||||||
|
|
||||||
|
# 导出模型为 ONNX 格式
|
||||||
|
torch.onnx.export(
|
||||||
|
model, # 模型
|
||||||
|
dummy_input, # 输入张量
|
||||||
|
args.onnx, # 导出的 ONNX 文件名
|
||||||
|
input_names=['input'], # 输入节点名称
|
||||||
|
output_names=['output'], # 输出节点名称
|
||||||
|
dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}} # 支持动态批量大小
|
||||||
|
)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Loading…
x
Reference in New Issue
Block a user