266 lines
11 KiB
Python
Executable File
266 lines
11 KiB
Python
Executable File
#!/usr/bin/env python3
|
||
from flask import Flask, request
|
||
import os
|
||
import argparse
|
||
import numpy as np
|
||
import qr_detection
|
||
from PIL import Image
|
||
from dot_detection import dots_angle_measure_dl_sr, dots_angle_measure_tr_sr
|
||
from thirdTool import qr_box_detect
|
||
from common import cropped_image
|
||
from qr_verify_Tool.qr_orb import ORB_detect, roi_siml
|
||
import json
|
||
from qr_verify_Tool.affinity import roi_affinity_siml
|
||
from qr_verify_Tool.roi_img_process import *
|
||
from qr_verify_Tool.models import *
|
||
from utils import get_model
|
||
from thirdTool import qrsrgan
|
||
|
||
# Flask application; every endpoint below is registered on it via @app.route.
app = Flask(__name__)

# Optional GPU pinning, currently disabled: order CUDA devices by PCI bus ID
# and expose only device "0" to this process.
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||
|
||
class Detection(object):
    """Container for the model weight locations and the loaded inference models.

    Call order matters: :meth:`get_models` resolves the weight files first,
    then :meth:`init_models` builds the detectors/generators from them.
    """

    def get_models(self):
        """Resolve the weight files for every model used by the API.

        Stores whatever ``get_model`` returns on ``self``. Presumably these are
        local file paths: they are later passed as ``model_path=`` and to
        ``torch.load`` by :meth:`init_models`.
        """
        self.qr_cloud_detect_model = get_model('qr_cloud_detect_20230928.pt')
        self.qr_roi_cloud_detect_model = get_model('qr_roi_cloud_detect_20230928.pt')
        self.qr_net_g_2000_model = get_model('net_g_2000.pth')
        self.roi_net_g_model = get_model("roi_net_g_20240306.pth")

    def init_models(self):
        """Instantiate all models from the weights resolved by :meth:`get_models`.

        The QR box detector, the ROI detector and the Real-ESRGAN dot
        super-resolution model are placed on the device named by the
        ``CUDA_DEVICE`` environment variable (default ``"cpu"``). The ROI
        restoration generator is always loaded onto the CPU.
        """
        # Imported explicitly; previously `torch` was only reachable through
        # the `from qr_verify_Tool.models import *` star import.
        import torch

        device = os.environ.get("CUDA_DEVICE", "cpu")

        self.qr_box_detect_now = qr_box_detect.QR_Box_detect(
            model_path=self.qr_cloud_detect_model, device=device)
        self.qr_roi_detect_now = qr_box_detect.QR_Box_detect(
            model_path=self.qr_roi_cloud_detect_model, device=device)

        self.dot_realesrgan = qrsrgan.RealsrGan(
            model_path=self.qr_net_g_2000_model, device=device)

        self.roi_generator = GeneratorUNet()
        state_dict = torch.load(self.roi_net_g_model, map_location=torch.device('cpu'))
        self.roi_generator.load_state_dict(state_dict)
        # The generator is only used for inference. Leave training mode so
        # layers such as dropout/normalisation use their inference behaviour.
        self.roi_generator.eval()
|
||
|
||
# Module-level singleton shared by all routes; main() attaches the models to it
# via get_models() and init_models() before the server starts.
detection = Detection()
|
||
|
||
@app.route('/upload', methods=['POST', 'GET'])
def upload():
    """Save the uploaded ``file`` form field to ``tmp/tmp.<ext>`` and return that path.

    Only the text after the last dot of the client-supplied filename is kept,
    and it is reduced to alphanumeric characters so a crafted filename cannot
    inject path components. The ``tmp`` directory is created if missing.
    Every upload with the same extension overwrites the previous file.

    Returns:
        The saved path as a string, or ``({"error": ...}, 400)`` when no file
        was sent.
    """
    f = request.files.get('file')
    if f is None or not f.filename:
        return {"error": "no file uploaded"}, 400

    # Equivalent to filename.split(".")[-1], but stripped of separators and
    # other non-alphanumeric characters.
    ext = "".join(ch for ch in f.filename.rsplit(".", 1)[-1] if ch.isalnum())

    # The target directory must exist or save() fails.
    os.makedirs("tmp", exist_ok=True)
    upload_path = "tmp/tmp." + ext
    f.save(upload_path)
    return upload_path
|
||
|
||
|
||
@app.route('/qr_roi_cloud_comparison', methods=['POST', 'GET'])
def qr_roi_cloud_comparison():
    """Compare an anti-counterfeit label feature region with its registered reference.

    Request (multipart form):
        std_file:  registered feature-capture region file (the original ROI
                   image including its black border).
        ter_file:  terminal feature-capture file (the top-left quarter of the
                   affine-rectified image).
        threshold: image comparison threshold, parsed as float.

    Returns:
        ``{"data": {"comparison_result": ..., "similarity_value": ...,
        "status": ...}}``. ``comparison_result`` is ``"pass"`` when the
        similarity is strictly greater than ``threshold``, else ``"failed"``.
        When a region cannot be found, ``status`` is ``'False'`` and a
        ``reason`` is given. Any exception yields ``{"error": <message>}``.
    """

    def do_api():
        import torch  # only needed to disable autograd during generator inference

        std_file = request.files.get('std_file')  # registered feature region file
        ter_file = request.files.get('ter_file')  # terminal-captured file
        threshold_raw = request.form.get('threshold')  # comparison threshold
        if std_file is None or ter_file is None or threshold_raw is None:
            raise ValueError("std_file, ter_file and threshold are required")
        threshold = float(threshold_raw)

        std_roi_img = Image.open(std_file)
        ter_img = Image.open(ter_file)

        # Process the registered feature region.
        std_roi_im_fix = roi_img_processing(std_roi_img)
        if std_roi_im_fix is None:
            return {"comparison_result": None, 'similarity_value': None, "status": 'False',
                    "reason": "cannot find roi region in std_file"}

        # Locate the feature-extraction region in the terminal image.
        ter_roi_points, _ = detection.qr_roi_detect_now(ter_img)
        if ter_roi_points is None:
            return {"comparison_result": None, 'similarity_value': None, "status": 'False',
                    "reason": "cannot detect roi region in ter_file"}
        ter_roi_img = cropped_image(ter_img, ter_roi_points, shift=-3, size=1)

        # Restore the terminal ROI with the generator (inference only: no gradients).
        with torch.no_grad():
            ter_roi_im_re = detection.roi_generator(ter_roi_img)

        # Process the restored ROI the same way as the registered one.
        ter_roi_im_re_fix = roi_img_processing(ter_roi_im_re)
        if ter_roi_im_re_fix is None:
            return {"comparison_result": None, 'similarity_value': None, "status": 'False',
                    "reason": "cannot find roi region after enhancing"}

        similarity = roi_affinity_siml(ter_roi_im_re_fix, std_roi_im_fix)
        comparison_result = "pass" if similarity > threshold else "failed"
        return {"comparison_result": comparison_result, 'similarity_value': similarity, "status": 'OK'}

    try:
        return {"data": do_api()}
    except Exception as e:
        return {"error": str(e)}
|
||
|
||
@app.route('/dot_detection_top', methods=['POST', 'GET'])
def dot_detection_top():
    """Measure the halftone-dot (screen) angle in a strip above the detected QR code.

    Request (multipart form):
        file:      image to analyse.
        threshold: maximum allowed absolute deviation from ``angle``.
        angle:     reference angle to compare against.

    The strip starts at the QR box's top-left x coordinate, is half the box
    width wide and 10% of the box height tall. Its top edge lies above the box
    by the first ratio in ``offset_ratios`` that keeps it inside the image.
    The angle is first measured with classic interpolation upscaling
    (``tr_sr``). If that result is missing or outside the tolerance, the
    model-based path (``dl_sr``) is tried.

    Returns:
        ``{"data": {...}}`` with ``dot_angle`` / ``method`` / ``status`` (and
        ``reason`` on failure), or ``{"error": <message>}`` on exceptions.
    """
    # Fractions of the QR box height used to offset the strip above the box.
    # Larger offsets are preferred; smaller ones are tried when the strip would
    # start above the image.
    offset_ratios = (0.25, 0.20, 0.18, 0.15, 0.14, 0.13, 0.12, 0.11)

    def do_api():
        image_file = request.files.get('file')
        threshold = request.form.get('threshold')
        angle = request.form.get('angle')
        if threshold is None or angle is None:
            raise ValueError("threshold and angle are required")
        # Parse once up front so bad input fails before any model work.
        threshold_value = float(threshold)
        angle_value = float(angle)

        def angle_ok(a):
            return a is not None and abs(a - angle_value) <= threshold_value

        img = Image.open(image_file)
        points, _ = detection.qr_box_detect_now(img)
        if points is None:
            raise Exception("Qr_Decode error")

        rotation_region = qr_detection.rotation_region_crop(img, points)
        img_ = np.array(qr_detection.rotation_waffine(Image.fromarray(rotation_region), img))

        box_w = points[1, 0] - points[0, 0]
        box_h = points[1, 1] - points[0, 1]

        # Top edge of the strip: first offset that stays inside the image.
        for ratio in offset_ratios:
            y1_1 = int(points[0, 1] - box_h * ratio)
            if y1_1 >= 0:
                break
        else:
            return {"status": "False", "reason": "coordinate invalid"}

        x1_1 = int(points[0, 0])
        x1_2 = int(x1_1 + box_w * 0.5)
        y1_2 = int(y1_1 + box_h * 0.1)

        # A negative start index would silently wrap around in numpy slicing,
        # and an empty crop cannot be measured.
        if x1_1 < 0:
            return {"status": "False", "reason": "coordinate invalid"}
        dots_region_top = img_[y1_1:y1_2, x1_1:x1_2]
        if dots_region_top.size == 0:
            return {"status": "False", "reason": "coordinate invalid"}

        # Enhance with classic interpolation, then measure the screen angle.
        tr_sr_lines_arctan, _ = dots_angle_measure_tr_sr(dots_region_top, save_dots=False,
                                                         res_code=None)
        if angle_ok(tr_sr_lines_arctan):
            return {"dot_angle": tr_sr_lines_arctan, "method": "tr_sr", "status": "OK"}

        # Fall back to model-based enhancement, then measure again.
        dl_sr_lines_arctan, _ = dots_angle_measure_dl_sr(dots_region_top, detection.dot_realesrgan,
                                                         save_dots=False, res_code=None)
        if dl_sr_lines_arctan is None:
            return {"status": "False", "reason": "Failed to find dot angle"}
        if angle_ok(dl_sr_lines_arctan):
            return {"dot_angle": dl_sr_lines_arctan, "method": "dl_sr","status": "OK"}
        reason = f"Angle {dl_sr_lines_arctan} not within threshold {threshold} from reference angle {angle}"
        return {"dot_angle": dl_sr_lines_arctan, "method": "dl_sr","status": "False", "reason": reason}

    try:
        return {"data": do_api()}
    except Exception as e:
        return {"error": str(e)}
|
||
@app.route('/dot_detection_bottom', methods=['POST', 'GET'])
def dot_detection_bottom():
    """Measure the halftone-dot (screen) angle in a strip below the detected QR code.

    Request (multipart form):
        file:      image to analyse.
        threshold: maximum allowed absolute deviation from ``angle``.
        angle:     reference angle to compare against.

    The strip is half the box width wide and 10% of the box height tall. Its
    bottom edge lies below the box by the first ratio in ``offset_ratios``
    that keeps it within the image height. The angle is first measured with
    classic interpolation upscaling (``tr_sr``). If that result is missing or
    outside the tolerance, the model-based path (``dl_sr``) is tried.

    Returns:
        ``{"data": {...}}`` with ``dot_angle`` / ``method`` / ``status`` (and
        ``reason`` on failure), or ``{"error": <message>}`` on exceptions.
    """
    # Fractions of the QR box height used to offset the strip below the box.
    # Larger offsets are preferred; smaller ones are tried when the strip would
    # extend past the image bottom.
    offset_ratios = (0.25, 0.20, 0.18, 0.15, 0.14, 0.13, 0.12, 0.11)

    def do_api():
        image_file = request.files.get('file')
        threshold = request.form.get('threshold')
        angle = request.form.get('angle')
        if threshold is None or angle is None:
            raise ValueError("threshold and angle are required")
        # Parse once up front so bad input fails before any model work.
        threshold_value = float(threshold)
        angle_value = float(angle)

        def angle_ok(a):
            return a is not None and abs(a - angle_value) <= threshold_value

        img = Image.open(image_file)
        points, _ = detection.qr_box_detect_now(img)
        if points is None:
            raise Exception("Qr_Decode error")

        rotation_region = qr_detection.rotation_region_crop(img, points)
        img_ = np.array(qr_detection.rotation_waffine(Image.fromarray(rotation_region), img))

        box_w = points[1, 0] - points[0, 0]
        box_h = points[1, 1] - points[0, 1]
        max_y2 = img_.shape[0]

        # Bottom edge of the strip: first offset that stays within the image height.
        for ratio in offset_ratios:
            y2_2 = int(points[1, 1] + box_h * ratio)
            if y2_2 <= max_y2:
                break
        else:
            return {"status": "False", "reason": "coordinate invalid"}

        # NOTE(review): the top endpoint uses the raw box x coordinate, but here
        # the absolute x coordinate is scaled by 0.8. The strip's horizontal
        # position therefore depends on where the QR code sits in the image.
        # Kept unchanged; confirm this is intended.
        x2_2 = int(points[1, 0] * 0.8)
        x2_1 = int(x2_2 - box_w * 0.5)
        y2_1 = int(y2_2 - box_h * 0.1)

        # A negative start index would silently wrap around in numpy slicing,
        # and an empty crop cannot be measured.
        if x2_1 < 0 or y2_1 < 0:
            return {"status": "False", "reason": "coordinate invalid"}
        dots_region_bottom = img_[y2_1:y2_2, x2_1:x2_2]
        if dots_region_bottom.size == 0:
            return {"status": "False", "reason": "coordinate invalid"}

        # Enhance with classic interpolation, then measure the screen angle.
        tr_sr_lines_arctan, _ = dots_angle_measure_tr_sr(dots_region_bottom, save_dots=False,
                                                         res_code=None)
        if angle_ok(tr_sr_lines_arctan):
            return {"dot_angle": tr_sr_lines_arctan, "method": "tr_sr", "status": "OK"}

        # Fall back to model-based enhancement, then measure again.
        dl_sr_lines_arctan, _ = dots_angle_measure_dl_sr(dots_region_bottom, detection.dot_realesrgan,
                                                         save_dots=False, res_code=None)
        if dl_sr_lines_arctan is None:
            return {"status": "False", "reason": "Failed to find dot angle"}
        if angle_ok(dl_sr_lines_arctan):
            return {"dot_angle": dl_sr_lines_arctan, "method": "dl_sr","status": "OK"}
        reason = f"Angle {dl_sr_lines_arctan} not within threshold {threshold} from reference angle {angle}"
        return {"dot_angle": dl_sr_lines_arctan, "method": "dl_sr","status": "False", "reason": reason}

    try:
        return {"data": do_api()}
    except Exception as e:
        return {"error": str(e)}
|
||
|
||
@app.route('/', methods=['GET'])
def index():
    """Landing/health-check endpoint that replies with a fixed banner string."""
    banner = 'emblem detection api'
    return banner
|
||
|
||
def parse_args(argv=None):
    """Parse the server's command-line options.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, matching the previous behaviour.

    Returns:
        argparse.Namespace with ``host``, ``port``, ``download_models_only``
        and ``debug``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", "-l", default="0.0.0.0")
    parser.add_argument("--port", "-p", type=int, default=6006)
    parser.add_argument("--download-models-only", action="store_true")
    # Opt-in only: Flask's interactive debugger allows arbitrary code execution
    # and must never be reachable on a public interface.
    parser.add_argument("--debug", action="store_true",
                        help="run Flask in debug mode (development only)")
    return parser.parse_args(argv)
|
||
|
||
def main():
    """Entry point: resolve the model weights, then either exit or start serving.

    With ``--download-models-only`` the process stops right after
    ``get_models()``. Otherwise the models are loaded and the Flask server
    starts. Debug mode is enabled only when the parsed arguments carry a truthy
    ``debug`` attribute. The Werkzeug debugger allows remote code execution,
    the default host is ``0.0.0.0``, and the debug reloader would re-run model
    loading in a child process.
    """
    args = parse_args()
    detection.get_models()
    if args.download_models_only:
        print("Models loaded successfully")
        return

    detection.init_models()
    print(f"Starting server at {args.host}:{args.port}")
    app.run(host=args.host, port=args.port, debug=getattr(args, "debug", False))
|
||
|
||
# Run the server only when executed as a script, not when imported.
if __name__ == '__main__':
    main()