2024-12-21 11:12:35 +00:00

250 lines
10 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
import cv2
from flask import Flask, request
import os
import argparse
import numpy as np
import qr_detection
from PIL import Image
from dot_detection import dots_angle_measure_dl_sr, dots_angle_measure_tr_sr
from thirdTool import qr_box_detect
from common import cropped_image
from qr_verify_Tool.qr_orb import ORB_detect, roi_siml
import json
from qr_verify_Tool.affinity import roi_affinity_siml
from qr_verify_Tool.roi_img_process import *
from qr_verify_Tool.models import *
from utils import get_model
# Flask application object; every endpoint below registers on it.
app = Flask(__name__)
# GPU selection knobs, currently disabled.
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# NOTE(review): NullLocator is not referenced anywhere in the visible code;
# this import looks like a leftover — confirm before removing.
from matplotlib.ticker import NullLocator
# Device string handed to the detectors and the SR model; read from the
# CUDA_DEVICE environment variable, defaulting to "cpu".
device = os.environ.get("CUDA_DEVICE", "cpu")
# All models below are loaded once, at import time.
# QR-code box detector; endpoints call it as ``points, conf = qr_box_detect_now(img)``
# and treat ``points is None`` as "no QR code found".
qr_box_detect_now = qr_box_detect.QR_Box_detect(model_path=get_model('qr_cloud_detect_20230928.pt'), device=device)
# Detector (same class) for the ROI feature-capture region of a terminal image.
qr_roi_detect_now = qr_box_detect.QR_Box_detect(model_path=get_model('qr_roi_cloud_detect_20230928.pt'), device=device)
from thirdTool import qrsrgan
# Super-resolution model passed to dots_angle_measure_dl_sr.
dot_realesrgan = qrsrgan.RealsrGan(model_path=get_model('net_g_2000.pth'), device=device)
# Generator applied to the cropped terminal ROI before comparison.
# ``GeneratorUNet`` and ``torch`` presumably come from the star imports above
# (qr_verify_Tool.models / roi_img_process) — TODO confirm.
roi_generator = GeneratorUNet()
# NOTE(review): weights are always mapped to CPU (``device`` is ignored here),
# and the module is never switched to ``eval()``. If GeneratorUNet contains
# dropout or batch-norm layers, inference may be non-deterministic — confirm
# this is intended.
roi_generator.load_state_dict(torch.load(get_model("roi_net_g_20240306.pth"), map_location=torch.device('cpu')))
@app.route('/upload', methods=['POST', 'GET'])
def upload():
    """Save the uploaded ``file`` form field as ``tmp/tmp.<ext>``.

    ``<ext>`` is the part of the client-supplied filename after its last
    ``.``, reduced to alphanumeric characters so the client cannot inject
    path separators into the destination path.

    Returns:
        The saved path as plain text, or ``({"error": ...}, 400)`` when the
        request carries no file.
    """
    f = request.files.get('file')
    if f is None or not f.filename:
        # Previously this crashed with AttributeError (HTTP 500) on e.g. GET.
        return {"error": "no file uploaded in form field 'file'"}, 400
    ext = "".join(ch for ch in f.filename.split(".")[-1] if ch.isalnum())
    upload_path = "tmp/tmp." + ext
    # The target folder must exist before saving, otherwise save() fails
    # with "no such file or directory".
    os.makedirs(os.path.dirname(upload_path), exist_ok=True)
    # NOTE(review): every upload with the same extension overwrites the same
    # file, so concurrent requests can clobber each other.
    f.save(upload_path)
    return upload_path
# Anti-counterfeit label feature comparison.
#
# Form inputs:
#   std_file  -- registered (filed) feature-capture region image: the original
#                ROI image including its black frame.
#   ter_file  -- terminal feature-capture image: the top-left quarter of the
#                image after affine rectification.
#   threshold -- image similarity threshold.
# Returns:
#   {"comparison_result": comparison result, "similarity_value": similarity
#    value, "status": status of the comparison call}
@app.route('/qr_roi_cloud_comparison', methods=['POST', 'GET'])
def qr_roi_cloud_comparison():
    """Compare a terminal-captured ROI against the registered ROI image.

    Returns:
        ``{"data": {...}}`` where the inner dict has ``comparison_result``
        ("pass" when similarity > threshold, else "failed"),
        ``similarity_value`` and ``status`` ("OK"), or ``status`` "False"
        plus a ``reason`` when an ROI cannot be located. Any raised
        exception is logged and returned as ``{"error": message}``.
    """
    def do_api():
        std_file = request.files.get('std_file')  # registered feature-capture region file
        ter_file = request.files.get('ter_file')  # terminal feature-capture file
        if std_file is None or ter_file is None:
            raise ValueError("both 'std_file' and 'ter_file' uploads are required")
        try:
            threshold = float(request.form.get('threshold'))  # comparison threshold
        except (TypeError, ValueError) as err:
            raise ValueError("'threshold' must be a numeric form field") from err
        std_roi_img = Image.open(std_file)
        ter_img = Image.open(ter_file)

        # Registered ROI: normalise it; bail out early if no ROI is found so
        # the expensive terminal-side inference is skipped.
        std_roi_im_fix = roi_img_processing(std_roi_img)
        if std_roi_im_fix is None:
            return {"comparison_result": None, 'similarity_value': None, "status": 'False',
                    "reason": "cannot find roi region in std_file"}

        # Locate and crop the feature-capture region in the terminal image.
        # The detector signals "not found" with points=None.
        ter_roi_points, _ = qr_roi_detect_now(ter_img)
        if ter_roi_points is None:
            return {"comparison_result": None, 'similarity_value': None, "status": 'False',
                    "reason": "cannot find roi region in ter_file"}
        ter_roi_img = cropped_image(ter_img, ter_roi_points, shift=-3, size=1)

        # Restore the terminal ROI with the generator. This is inference only,
        # so skip autograd bookkeeping (saves memory, returns grad-free output).
        with torch.no_grad():
            ter_roi_im_re = roi_generator(ter_roi_img)
        ter_roi_im_re_fix = roi_img_processing(ter_roi_im_re)
        if ter_roi_im_re_fix is None:
            return {"comparison_result": None, 'similarity_value': None, "status": 'False',
                    "reason": "cannot find roi region after enhancing"}

        similarity = roi_affinity_siml(ter_roi_im_re_fix, std_roi_im_fix)
        comparison_result = "pass" if similarity > threshold else "failed"
        return {"comparison_result": comparison_result, 'similarity_value': similarity, "status": 'OK'}

    try:
        return {"data": do_api()}
    except Exception as e:
        # Keep the JSON error contract, but don't lose the traceback.
        app.logger.exception("qr_roi_cloud_comparison failed")
        return {"error": str(e)}
# Dot-screen angle detection at the top position (above the QR code).
@app.route('/dot_detection_top', methods=['POST', 'GET'])
def dot_detection_top():
    """Measure the printed dot-screen angle in a strip above the QR code.

    Form fields:
        file: uploaded image containing the QR code.
        threshold: maximum allowed absolute difference between the measured
            angle and ``angle``.
        angle: reference angle to compare against.

    The strip starts at the QR box's left x coordinate, is half the box width
    wide and one tenth of the box height tall. It sits above the box, offset
    by a fraction of the box height: 25% first, shrinking step by step down
    to 11% until the strip's top edge lies inside the image. The angle is
    measured first with traditional interpolation upscaling (tr_sr). If that
    measurement is missing or out of tolerance, the model-based
    super-resolution path (dl_sr) is tried.

    Returns:
        ``{"data": {...}}`` with ``dot_angle``/``method``/``status`` (plus a
        ``reason`` on failure), or ``{"error": message}`` if an exception is
        raised (e.g. no QR code detected, bad form fields).
    """
    # Offsets (fractions of the box height) tried in order for the strip.
    offset_ratios = (0.25, 0.20, 0.18, 0.15, 0.14, 0.13, 0.12, 0.11)

    def do_api():
        image_file = request.files.get('file')
        threshold = request.form.get('threshold')
        angle = request.form.get('angle')
        if image_file is None:
            raise ValueError("missing 'file' upload")
        # Parse the numeric fields once, up front, so a bad request fails
        # before any model inference runs.
        try:
            max_deviation = float(threshold)
            reference_angle = float(angle)
        except (TypeError, ValueError) as err:
            raise ValueError("'threshold' and 'angle' must be numeric form fields") from err

        def angle_ok(a):
            # A measurement passes when it lies within the threshold of the reference.
            return a is not None and abs(a - reference_angle) <= max_deviation

        img = Image.open(image_file)
        points, _ = qr_box_detect_now(img)
        if points is None:
            raise Exception("Qr_Decode error")
        rotation_region = qr_detection.rotation_region_crop(img, points)
        img_ = np.array(qr_detection.rotation_waffine(Image.fromarray(rotation_region), img))

        # Assumes points[0]/points[1] are the box's top-left/bottom-right
        # corners. NOTE(review): points come from the un-rotated image but
        # index the warped image ``img_`` — confirm they share a frame.
        box_w = points[1, 0] - points[0, 0]
        box_h = points[1, 1] - points[0, 1]
        for ratio in offset_ratios:
            y1_1 = int(points[0, 1] - box_h * ratio)
            if y1_1 >= 0:
                break
        else:
            return {"status": "False", "reason": "coordinate invalid"}
        x1_1 = int(points[0, 0])
        x1_2 = int(x1_1 + box_w * 0.5)
        y1_2 = int(y1_1 + box_h * 0.1)
        # Clamp the left edge: a negative slice start would wrap around to the
        # far side of the image instead of starting at column 0.
        dots_region_top = img_[y1_1:y1_2, max(x1_1, 0):x1_2]
        if dots_region_top.size == 0:
            return {"status": "False", "reason": "coordinate invalid"}

        # Traditional interpolation upscaling, then dot-screen angle measurement.
        tr_sr_lines_arctan, _ = dots_angle_measure_tr_sr(dots_region_top, save_dots=False,
                                                         res_code=None)
        if angle_ok(tr_sr_lines_arctan):
            return {"dot_angle": tr_sr_lines_arctan, "method": "tr_sr", "status": "OK"}
        # Fall back to model-based super-resolution (dot_realesrgan).
        dl_sr_lines_arctan, _ = dots_angle_measure_dl_sr(dots_region_top, dot_realesrgan, save_dots=False,
                                                         res_code=None)
        if dl_sr_lines_arctan is None:
            return {"status": "False", "reason": "Failed to find dot angle"}
        if angle_ok(dl_sr_lines_arctan):
            return {"dot_angle": dl_sr_lines_arctan, "method": "dl_sr", "status": "OK"}
        reason = f"Angle {dl_sr_lines_arctan} not within threshold {threshold} from reference angle {angle}"
        return {"dot_angle": dl_sr_lines_arctan, "method": "dl_sr", "status": "False", "reason": reason}

    try:
        return {"data": do_api()}
    except Exception as e:
        # Keep the JSON error contract, but don't lose the traceback.
        app.logger.exception("dot_detection_top failed")
        return {"error": str(e)}
# Dot-screen angle detection at the bottom position (below the QR code).
@app.route('/dot_detection_bottom', methods=['POST', 'GET'])
def dot_detection_bottom():
    """Measure the printed dot-screen angle in a strip below the QR code.

    Form fields:
        file: uploaded image containing the QR code.
        threshold: maximum allowed absolute difference between the measured
            angle and ``angle``.
        angle: reference angle to compare against.

    The strip's right edge is ``0.8 * points[1, 0]``. It is half the box
    width wide and one tenth of the box height tall. Its bottom edge lies
    below the box, offset by a fraction of the box height: 25% first,
    shrinking step by step down to 11% until it fits within the image height.
    The angle is measured first with traditional interpolation upscaling
    (tr_sr). If that measurement is missing or out of tolerance, the
    model-based super-resolution path (dl_sr) is tried.

    Returns:
        ``{"data": {...}}`` with ``dot_angle``/``method``/``status`` (plus a
        ``reason`` on failure), or ``{"error": message}`` if an exception is
        raised (e.g. no QR code detected, bad form fields).
    """
    # Offsets (fractions of the box height) tried in order for the strip.
    offset_ratios = (0.25, 0.20, 0.18, 0.15, 0.14, 0.13, 0.12, 0.11)

    def do_api():
        image_file = request.files.get('file')
        threshold = request.form.get('threshold')
        angle = request.form.get('angle')
        if image_file is None:
            raise ValueError("missing 'file' upload")
        # Parse the numeric fields once, up front, so a bad request fails
        # before any model inference runs.
        try:
            max_deviation = float(threshold)
            reference_angle = float(angle)
        except (TypeError, ValueError) as err:
            raise ValueError("'threshold' and 'angle' must be numeric form fields") from err

        def angle_ok(a):
            # A measurement passes when it lies within the threshold of the reference.
            return a is not None and abs(a - reference_angle) <= max_deviation

        img = Image.open(image_file)
        points, _ = qr_box_detect_now(img)
        if points is None:
            raise Exception("Qr_Decode error")
        rotation_region = qr_detection.rotation_region_crop(img, points)
        img_ = np.array(qr_detection.rotation_waffine(Image.fromarray(rotation_region), img))

        # Assumes points[0]/points[1] are the box's top-left/bottom-right
        # corners. NOTE(review): points come from the un-rotated image but
        # index the warped image ``img_`` — confirm they share a frame.
        box_w = points[1, 0] - points[0, 0]
        box_h = points[1, 1] - points[0, 1]
        # NOTE(review): this scales the absolute x coordinate (not the box
        # width) by 0.8 — confirm that is the intended placement.
        x2_2 = int(points[1, 0] * 0.8)
        max_y2 = img_.shape[0]
        for ratio in offset_ratios:
            y2_2 = int(points[1, 1] + box_h * ratio)
            if y2_2 <= max_y2:
                break
        else:
            return {"status": "False", "reason": "coordinate invalid"}
        x2_1 = int(x2_2 - box_w * 0.5)
        y2_1 = int(y2_2 - box_h * 0.1)
        # Clamp the start edges: a negative slice start would wrap around to
        # the far side of the image instead of starting at index 0.
        dots_region_bottom = img_[max(y2_1, 0):y2_2, max(x2_1, 0):x2_2]
        if dots_region_bottom.size == 0:
            return {"status": "False", "reason": "coordinate invalid"}

        # Traditional interpolation upscaling, then dot-screen angle measurement.
        tr_sr_lines_arctan, _ = dots_angle_measure_tr_sr(dots_region_bottom, save_dots=False,
                                                         res_code=None)
        if angle_ok(tr_sr_lines_arctan):
            return {"dot_angle": tr_sr_lines_arctan, "method": "tr_sr", "status": "OK"}
        # Fall back to model-based super-resolution (dot_realesrgan).
        dl_sr_lines_arctan, _ = dots_angle_measure_dl_sr(dots_region_bottom, dot_realesrgan, save_dots=False,
                                                         res_code=None)
        if dl_sr_lines_arctan is None:
            return {"status": "False", "reason": "Failed to find dot angle"}
        if angle_ok(dl_sr_lines_arctan):
            return {"dot_angle": dl_sr_lines_arctan, "method": "dl_sr", "status": "OK"}
        reason = f"Angle {dl_sr_lines_arctan} not within threshold {threshold} from reference angle {angle}"
        return {"dot_angle": dl_sr_lines_arctan, "method": "dl_sr", "status": "False", "reason": reason}

    try:
        return {"data": do_api()}
    except Exception as e:
        # Keep the JSON error contract, but don't lose the traceback.
        app.logger.exception("dot_detection_bottom failed")
        return {"error": str(e)}
@app.route('/', methods=['GET'])
def index():
    """Root endpoint: identify the service with a fixed plain-text banner."""
    service_name = 'emblem detection api'
    return service_name
def parse_args(argv=None):
    """Parse the detection server's command-line options.

    Args:
        argv: argument list to parse; ``None`` (the default) reads
            ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``host`` (str, default "0.0.0.0") and
        ``port`` (int, default 6006).
    """
    parser = argparse.ArgumentParser(description="emblem detection api server")
    parser.add_argument("--host", "-l", default="0.0.0.0",
                        help="interface to bind (default: %(default)s)")
    parser.add_argument("--port", "-p", type=int, default=6006,
                        help="port to listen on (default: %(default)s)")
    return parser.parse_args(argv)
if __name__ == '__main__':
    args = parse_args()
    # Debug mode has two effects that make it unsafe as a default:
    # - It enables the Werkzeug interactive debugger, which can execute
    #   arbitrary code from the browser, while the server binds 0.0.0.0
    #   by default.
    # - It enables the reloader, which re-runs this script in a child process,
    #   so every model loaded at module level is loaded twice.
    # Make it opt-in via FLASK_DEBUG=1.
    debug = os.environ.get("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes")
    print(f"starting detection server on {args.host}:{args.port}")
    app.run(host=args.host, port=args.port, debug=debug)