drop alg roi files

This commit is contained in:
Fam Zheng 2025-04-25 08:26:05 +01:00
parent 00b19adb50
commit 2b7e63274a
2 changed files with 0 additions and 65 deletions

View File

@ -1,30 +0,0 @@
#!/usr/bin/env python3
from PIL import Image
import os
import argparse
import sys
from roi_lib import *
def parse_args():
parser = argparse.ArgumentParser(description='ROI prediction')
parser.add_argument('--model', type=str, required=True, help='model path')
parser.add_argument('--image', type=str, required=True, help='image file')
return parser.parse_args()
# 主函数
def main():
args = parse_args()
model = load_model(args.model)
image_path = args.image
image_tensor = preprocess_image(image_path)
predicted_class, probabilities = predict(model, image_tensor)
print(f'{image_path} predicted={predicted_class} prob={probabilities}')
if predicted_class == 1:
print("verify ok")
return 0
else:
print("verify ng")
return 1
if __name__ == '__main__':
sys.exit(main())

View File

@ -1,35 +0,0 @@
#!/usr/bin/env python3
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import os
import argparse
import random
def load_model(model_path):
model = models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))) # 加载模型权重
model.eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
return model
def preprocess_image(image_path):
transform = transforms.Compose([
transforms.ToTensor(), # 转换为Tensor
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 归一化
])
image = Image.open(image_path).convert('RGB') # 打开图像并转换为 RGB
image = transform(image).unsqueeze(0) # 增加 batch 维度
return image
def predict(model, image_tensor):
with torch.no_grad(): # 禁用梯度计算
output = model(image_tensor)
_, predicted = torch.max(output, 1) # 获取预测类别
probabilities = torch.nn.functional.softmax(output, dim=1) # 计算概率
return predicted.item(), probabilities.squeeze().tolist()