Copy roi-verify.py to alg/
This commit is contained in:
parent
d5c2be6727
commit
e3d10b7a4c
3
Makefile
3
Makefile
@ -51,6 +51,9 @@ ALG_FILES = \
|
|||||||
$(shell find -L \
|
$(shell find -L \
|
||||||
alg/qrtool \
|
alg/qrtool \
|
||||||
alg/wechat_qrcode \
|
alg/wechat_qrcode \
|
||||||
|
alg/roi_lib.py \
|
||||||
|
alg/roi-verify.py \
|
||||||
|
alg/models/resnet18_20250325_114510_94.56.pth \
|
||||||
) \
|
) \
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
BIN
alg/models/resnet18_20250325_114510_94.56.pth
Normal file
BIN
alg/models/resnet18_20250325_114510_94.56.pth
Normal file
Binary file not shown.
30
alg/roi-verify.py
Executable file
30
alg/roi-verify.py
Executable file
@ -0,0 +1,30 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
from PIL import Image
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import sys
|
||||||
|
from roi_lib import *
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(description='ROI prediction')
|
||||||
|
parser.add_argument('--model', type=str, required=True, help='model path')
|
||||||
|
parser.add_argument('--image', type=str, required=True, help='image file')
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
# 主函数
|
||||||
|
def main():
|
||||||
|
args = parse_args()
|
||||||
|
model = load_model(args.model)
|
||||||
|
image_path = args.image
|
||||||
|
image_tensor = preprocess_image(image_path)
|
||||||
|
predicted_class, probabilities = predict(model, image_tensor)
|
||||||
|
print(f'{image_path} predicted={predicted_class} prob={probabilities}')
|
||||||
|
if predicted_class == 1:
|
||||||
|
print("verify ok")
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
print("verify ng")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
sys.exit(main())
|
||||||
35
alg/roi_lib.py
Executable file
35
alg/roi_lib.py
Executable file
@ -0,0 +1,35 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torchvision.models as models
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
from PIL import Image
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import random
|
||||||
|
|
||||||
|
def load_model(model_path):
|
||||||
|
model = models.resnet18(pretrained=True)
|
||||||
|
num_ftrs = model.fc.in_features
|
||||||
|
model.fc = nn.Linear(num_ftrs, 2)
|
||||||
|
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))) # 加载模型权重
|
||||||
|
model.eval()
|
||||||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
model = model.to(device)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def preprocess_image(image_path):
|
||||||
|
transform = transforms.Compose([
|
||||||
|
transforms.ToTensor(), # 转换为Tensor
|
||||||
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 归一化
|
||||||
|
])
|
||||||
|
image = Image.open(image_path).convert('RGB') # 打开图像并转换为 RGB
|
||||||
|
image = transform(image).unsqueeze(0) # 增加 batch 维度
|
||||||
|
return image
|
||||||
|
|
||||||
|
def predict(model, image_tensor):
|
||||||
|
with torch.no_grad(): # 禁用梯度计算
|
||||||
|
output = model(image_tensor)
|
||||||
|
_, predicted = torch.max(output, 1) # 获取预测类别
|
||||||
|
probabilities = torch.nn.functional.softmax(output, dim=1) # 计算概率
|
||||||
|
return predicted.item(), probabilities.squeeze().tolist()
|
||||||
@ -687,8 +687,8 @@ class QrVerifyView(BaseView):
|
|||||||
cmd = [qrtool_path, 'side_by_side', img_fn, f.name]
|
cmd = [qrtool_path, 'side_by_side', img_fn, f.name]
|
||||||
messages.append("creating side by side image...")
|
messages.append("creating side by side image...")
|
||||||
subprocess.check_call(cmd, cwd=cwd)
|
subprocess.check_call(cmd, cwd=cwd)
|
||||||
roi_verify_py = os.path.abspath("../research/roi-verify.py")
|
roi_verify_py = os.path.abspath("../alg/roi-verify.py")
|
||||||
roi_verify_model = os.path.abspath("../research/models/resnet18_20250325_114510_94.56.pth")
|
roi_verify_model = os.path.abspath("../alg/models/resnet18_20250325_114510_94.56.pth")
|
||||||
cmd = [roi_verify_py, '--model', roi_verify_model, '--image', side_by_side_fn]
|
cmd = [roi_verify_py, '--model', roi_verify_model, '--image', side_by_side_fn]
|
||||||
messages.append(" ".join(cmd))
|
messages.append(" ".join(cmd))
|
||||||
r = subprocess.call(cmd)
|
r = subprocess.call(cmd)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user