34 lines
1.1 KiB
Python
Executable File
34 lines
1.1 KiB
Python
Executable File
#!/usr/bin/env python3
|
||
import argparse
|
||
|
||
def parse_args():
|
||
parser = argparse.ArgumentParser()
|
||
parser.add_argument("pth", type=str)
|
||
parser.add_argument("onnx", type=str)
|
||
return parser.parse_args()
|
||
|
||
def main():
|
||
args = parse_args()
|
||
# 加载训练好的模型
|
||
import torch
|
||
import torchvision.models as models
|
||
model = models.resnet18(pretrained=False)
|
||
model.fc = torch.nn.Linear(model.fc.in_features, 2) # 假设是二分类任务
|
||
model.load_state_dict(torch.load(args.pth)) # 加载训练好的权重
|
||
model.eval() # 设置为评估模式
|
||
|
||
# 创建一个随机输入张量(假设输入图像大小为 128x64)
|
||
dummy_input = torch.randn(1, 3, 128, 64)
|
||
|
||
# 导出模型为 ONNX 格式
|
||
torch.onnx.export(
|
||
model, # 模型
|
||
dummy_input, # 输入张量
|
||
args.onnx, # 导出的 ONNX 文件名
|
||
input_names=['input'], # 输入节点名称
|
||
output_names=['output'], # 输出节点名称
|
||
dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}} # 支持动态批量大小
|
||
)
|
||
|
||
if __name__ == "__main__":
|
||
main() |