diff --git a/Makefile b/Makefile index 4ca8757..85cd0e3 100644 --- a/Makefile +++ b/Makefile @@ -51,6 +51,9 @@ ALG_FILES = \ $(shell find -L \ alg/qrtool \ alg/wechat_qrcode \ + alg/roi_lib.py \ + alg/roi-verify.py \ + alg/models/resnet18_20250325_114510_94.56.pth \ ) \ ) diff --git a/alg/models/resnet18_20250325_114510_94.56.pth b/alg/models/resnet18_20250325_114510_94.56.pth new file mode 100644 index 0000000..5835930 Binary files /dev/null and b/alg/models/resnet18_20250325_114510_94.56.pth differ diff --git a/alg/roi-verify.py b/alg/roi-verify.py new file mode 100755 index 0000000..19384f6 --- /dev/null +++ b/alg/roi-verify.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 +from PIL import Image +import os +import argparse +import sys +from roi_lib import * + +def parse_args(): + parser = argparse.ArgumentParser(description='ROI prediction') + parser.add_argument('--model', type=str, required=True, help='model path') + parser.add_argument('--image', type=str, required=True, help='image file') + return parser.parse_args() + +# 主函数 +def main(): + args = parse_args() + model = load_model(args.model) + image_path = args.image + image_tensor = preprocess_image(image_path) + predicted_class, probabilities = predict(model, image_tensor) + print(f'{image_path} predicted={predicted_class} prob={probabilities}') + if predicted_class == 1: + print("verify ok") + return 0 + else: + print("verify ng") + return 1 + +if __name__ == '__main__': + sys.exit(main()) \ No newline at end of file diff --git a/alg/roi_lib.py b/alg/roi_lib.py new file mode 100755 index 0000000..d360d8a --- /dev/null +++ b/alg/roi_lib.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python3 +import torch +import torch.nn as nn +import torchvision.models as models +import torchvision.transforms as transforms +from PIL import Image +import os +import argparse +import random + +def load_model(model_path): + model = models.resnet18(pretrained=True) + num_ftrs = model.fc.in_features + model.fc = nn.Linear(num_ftrs, 2) + model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))) # 加载模型权重 + model.eval() + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model = model.to(device) + return model + +def preprocess_image(image_path): + transform = transforms.Compose([ + transforms.ToTensor(), # 转换为Tensor + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 归一化 + ]) + image = Image.open(image_path).convert('RGB') # 打开图像并转换为 RGB + image = transform(image).unsqueeze(0) # 增加 batch 维度 + return image + +def predict(model, image_tensor): + with torch.no_grad(): # 禁用梯度计算 + output = model(image_tensor) + _, predicted = torch.max(output, 1) # 获取预测类别 + probabilities = torch.nn.functional.softmax(output, dim=1) # 计算概率 + return predicted.item(), probabilities.squeeze().tolist() \ No newline at end of file diff --git a/api/products/views.py b/api/products/views.py index 5b39b83..511990b 100644 --- a/api/products/views.py +++ b/api/products/views.py @@ -687,8 +687,8 @@ class QrVerifyView(BaseView): cmd = [qrtool_path, 'side_by_side', img_fn, f.name] messages.append("creating side by side image...") subprocess.check_call(cmd, cwd=cwd) - roi_verify_py = os.path.abspath("../research/roi-verify.py") - roi_verify_model = os.path.abspath("../research/models/resnet18_20250325_114510_94.56.pth") + roi_verify_py = os.path.abspath("../alg/roi-verify.py") + roi_verify_model = os.path.abspath("../alg/models/resnet18_20250325_114510_94.56.pth") cmd = [roi_verify_py, '--model', roi_verify_model, '--image', side_by_side_fn] messages.append(" ".join(cmd)) r = subprocess.call(cmd)